How to use BackPACK

If you haven’t already installed it,

pip install backpack-for-pytorch

To use BackPACK with your setup, you will need to extend the model and the loss function, and register the extension you want to use inside a with backpack(...) block before calling backward().

Extending the model and loss function

The extend(torch.nn.Module) function tells BackPACK what part of the computation graph needs to be tracked. If your model is a torch.nn.Sequential and you use one of the torch.nn loss functions, extending both is a one-line change:

import torch
from backpack import extend
from utils import load_data

X, y = load_data()

model = torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10)
)
lossfunc = torch.nn.CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

Calling the extension

To activate an extension, call backward() inside a with backpack(extension): block:

from backpack import backpack
from backpack.extensions import KFAC

loss = lossfunc(model(X), y)

with backpack(KFAC()):
    loss.backward()

for param in model.parameters():
    print(param.grad)
    print(param.kfac)

See Available Extensions for the other quantities BackPACK can compute, and Supported models for the supported architectures.
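For instance, the first-order extension BatchGrad stores one gradient per sample in the mini-batch. A minimal sketch, reusing the extended model and loss function from the example above:

from backpack import backpack
from backpack.extensions import BatchGrad

loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

for param in model.parameters():
    # param.grad_batch holds the individual gradients,
    # with shape [batch_size, *param.shape]
    print(param.grad_batch.shape)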


backpack.extend(module: torch.nn.modules.module.Module, debug=False)

Extends a module to make it BackPACK-ready.

If the module has children, e.g. for a torch.nn.Sequential, they will also be extended.

Parameters
  • module (torch.nn.Module) – The module to extend.

  • debug (bool, optional) – Print debug messages during the extension. Default: False.

Returns

Extended module.

Return type

torch.nn.Module
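A short sketch of extending a container with debug output enabled; the layer sizes are arbitrary:

import torch
from backpack import extend

# Extending a container also extends its children, so extending the
# Sequential below makes the Linear and ReLU submodules BackPACK-ready.
layers = torch.nn.Sequential(torch.nn.Linear(784, 64), torch.nn.ReLU())
layers = extend(layers, debug=True)  # prints debug messages while extending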

backpack.backpack(*exts: backpack.extensions.backprop_extension.BackpropExtension, extension_hook=None, debug=False)

Activate BackPACK extensions.

Enables the BackPACK extensions passed as arguments in the backward calls inside the current with block.

Parameters
  • exts ([BackpropExtension]) – Extensions to activate in the backward pass.

  • extension_hook (function, optional) –

    Function called on each module after all BackPACK extensions have run. Takes a torch.nn.Module and returns None. Default: None (no operation will be performed).

    Can be used to reduce memory overhead if the goal is to compute transformations of BackPACK quantities. Information can be compacted during a backward pass and obsolete tensors can be freed manually with del (see the sketch after this parameter list).

    Note

    If the callable iterates over the module.parameters(), the same parameter may be seen multiple times across calls. This happens if the parameters are part of multiple modules. For example, the parameters of a torch.nn.Linear module in model = torch.nn.Sequential(torch.nn.Linear(...)) are part of both the Linear and the Sequential.

  • debug (bool, optional) – Print debug messages during the backward pass. Default: False.
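As a sketch of such a hook, the function below compacts the per-sample gradients computed by the BatchGrad extension into per-sample squared gradient norms and frees the full grad_batch tensors; the attribute name batch_l2 is only illustrative, and loss is assumed to come from a forward pass through the extended model as in the usage example above.

from backpack import backpack
from backpack.extensions import BatchGrad

def compact_grad_batch(module):
    # Called on each module after BackPACK's extensions have run.
    for param in module.parameters():
        # Guard against seeing the same parameter twice (see the Note above).
        if hasattr(param, "grad_batch"):
            # Keep only the per-sample squared gradient norms ...
            param.batch_l2 = (param.grad_batch ** 2).flatten(start_dim=1).sum(dim=1)
            # ... and free the memory-hungry per-sample gradients.
            del param.grad_batch

with backpack(BatchGrad(), extension_hook=compact_grad_batch):
    loss.backward()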

backpack.disable()

Entirely disable BackPACK, including storage of input and output.

To compute the additional quantities, BackPACK needs to know the input and output of the modules in the computation graph. It saves those by default. disable tells BackPACK to _not_ save this information during the forward.

This can be useful if you only want a plain PyTorch gradient on a module that is extended with BackPACK and need to avoid the memory overhead. If you do not need any gradient, use the torch.no_grad context instead.

This context is not the exact opposite of the backpack context. The backpack context enables specific extensions during a backward. This context disables storing input/output information during a forward.

Note

Using with backpack(...) inside a with disable() context will fail, even if the forward pass is carried out in with backpack(...).
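A minimal sketch of using disable, reusing the extended model and loss function from the usage example above:

from backpack import disable

# Forward and backward inside disable(): BackPACK does not store module
# inputs/outputs, so only the plain PyTorch .grad fields are populated.
with disable():
    loss = lossfunc(model(X), y)
    loss.backward()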