How to use BackPACK

If you haven’t already installed it:

pip install backpack-for-pytorch

To use BackPACK with your setup, you need to extend the model and the loss function, and then activate the extensions you want with backpack when calling backward().

Extending the model and loss function

The extend(torch.nn.Module) function tells BackPACK which parts of the computation graph it needs to track. If your model is a torch.nn.Sequential and you use one of the torch.nn loss functions, extend both:

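import torch
from backpack import extend
from utils import load_data  # your own data-loading helper; not part of BackPACK

X, y = load_data()  # inputs of shape [N, 784] and integer class labels of shape [N]

model = torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10)
)
lossfunc = torch.nn.CrossEntropyLoss()

# Register BackPACK's hooks on the model (including its children) and on the loss
model = extend(model)
lossfunc = extend(lossfunc)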

Calling the extension

To activate an extension, call backward() inside a with backpack(extension): block:

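from backpack import backpack
from backpack.extensions import KFAC

loss = lossfunc(model(X), y)

# BackPACK computes the requested quantities during this backward pass
with backpack(KFAC()):
    loss.backward()

for param in model.parameters():
    print(param.grad)  # the usual gradient, computed by PyTorch
    print(param.kfac)  # Kronecker factors added by the KFAC extension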

See Available Extensions for other quantities, and Supported models for the models BackPACK works with.


backpack.extend(module: torch.nn.modules.module.Module, debug=False)

Extends the module to make it BackPACK-ready.

If the module has children, e.g. for a torch.nn.Sequential, they will also be extended.

Args:
    module: torch.nn.Module
        The module to extend.
    debug: bool, optional (default: False)
        If True, print debug messages during the extension.

backpack.backpack(*exts, extension_hook=None, debug=False)

Activates BackPACK extensions.

Activates the BackPACK extensions passed as arguments for the backward calls in the current with block.