How to use BackPACK
If you haven’t already installed it, install BackPACK with pip:
pip install backpack-for-pytorch
To use BackPACK with your setup, you need to extend the model and the loss function, and then call backward() inside a backpack(...) block that activates the extension you want to use.
Extending the model and loss function
The extend(torch.nn.Module) function tells BackPACK which parts of the computation graph need to be tracked. If your model is a torch.nn.Sequential and you use one of the torch.nn loss functions, extend both of them:
import torch
from backpack import extend
from utils import load_data  # placeholder for your own data-loading helper

X, y = load_data()

model = torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10)
)
lossfunc = torch.nn.CrossEntropyLoss()

# Let BackPACK track the model and the loss during the backward pass
model = extend(model)
lossfunc = extend(lossfunc)
Calling the extension
To activate an extension, call backward() inside a with backpack(extension): block:
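from backpack import backpack
from backpack.extensions import KFAC

# model, lossfunc, X and y are the extended objects and data from above
loss = lossfunc(model(X), y)

# BackPACK only computes extra quantities for backward() calls inside this block
with backpack(KFAC()):
    loss.backward()

for param in model.parameters():
    print(param.grad)  # the usual gradient, computed by PyTorch
    print(param.kfac)  # the KFAC approximation added by BackPACK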
See Available Extensions for other quantities, and Supported models for the models BackPACK works with.
backpack.extend(module: torch.nn.modules.module.Module, debug=False)
Extends the module to make it BackPACK-ready. If the module has children, e.g. for a torch.nn.Sequential, they will also be extended.
Args:
- module: torch.nn.Module
  The module to extend.
- debug: bool, optional (default: False)
  If True, will print debug messages during the extension.
backpack.backpack(*exts, debug=False)
Activates BackPACK extensions.
Activates the BackPACK extensions passed as arguments for the backward calls in the current with block.