How to use BackPACK¶
If you haven’t already installed it,
pip install backpack-for-pytorch
To use BackPACK with your setup, you will need to extend the model and the loss function, and register the extension you want to use with backpack before calling backward().
Extending the model and loss function¶
The extend(torch.nn.Module) function tells BackPACK what part of the computation graph needs to be tracked. If your model is a torch.nn.Sequential and you use one of the torch.nn loss functions:
import torch
from backpack import extend
from utils import load_data

X, y = load_data()

model = torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)
lossfunc = torch.nn.CrossEntropyLoss()

# extend the model and the loss function so BackPACK can track them
model = extend(model)
lossfunc = extend(lossfunc)
Calling the extension¶
To activate an extension, call backward() inside a with backpack(extension): block:
from backpack import backpack
from backpack.extensions import KFAC

loss = lossfunc(model(X), y)

with backpack(KFAC()):
    loss.backward()

for param in model.parameters():
    print(param.grad)  # standard gradient, computed by PyTorch as usual
    print(param.kfac)  # Kronecker factors, stored by the KFAC extension
See Available Extensions for other quantities, and Supported models for the architectures BackPACK works with.
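Several extensions can also be active in the same backward pass. The following is a minimal sketch reusing model, lossfunc, X, and y from above and assuming the BatchGrad and Variance extensions suit your use case; each extension stores its quantity as a new attribute on the parameters:

from backpack import backpack
from backpack.extensions import BatchGrad, Variance

loss = lossfunc(model(X), y)

with backpack(BatchGrad(), Variance()):
    loss.backward()

for param in model.parameters():
    print(param.grad_batch.shape)  # individual gradients, one per sample
    print(param.variance.shape)    # gradient variance, same shape as param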
- backpack.extend(module: torch.nn.modules.module.Module, debug: bool = False, use_converter: bool = False) → torch.nn.modules.module.Module¶
Recursively extend a module to make it BackPACK-ready.
Modules that do not represent an operation in the computation graph (for instance containers like Sequential) will not explicitly be extended.
- Parameters
module – The module to extend.
debug – Print debug messages during the extension. Default: False.
use_converter – Try converting the module to a BackPACK-compatible network. The converter might alter the model, e.g. the order of parameters. Default: False.
- Returns
Extended module.
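For architectures that are not a plain torch.nn.Sequential (for example models with skip connections), use_converter asks BackPACK to convert the network into a compatible form before extending it. The following is only a sketch under that assumption; SkipBlock is a made-up example module, and whether a particular architecture converts cleanly depends on the converter:

import torch
from backpack import extend

class SkipBlock(torch.nn.Module):
    """Toy model with a residual connection, i.e. not a plain Sequential."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(784, 784)
        self.linear2 = torch.nn.Linear(784, 10)

    def forward(self, x):
        return self.linear2(x + self.linear1(x))

# convert, then extend; debug=True prints what BackPACK does along the way
model = extend(SkipBlock(), use_converter=True, debug=True)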
- class backpack.backpack(*exts: backpack.extensions.backprop_extension.BackpropExtension, extension_hook: Optional[Callable[[torch.nn.modules.module.Module], None]] = None, debug: bool = False, retain_graph: bool = False)¶
Context manager to activate BackPACK extensions.
- __init__(*exts: backpack.extensions.backprop_extension.BackpropExtension, extension_hook: Optional[Callable[[torch.nn.modules.module.Module], None]] = None, debug: bool = False, retain_graph: bool = False)¶
Activate BackPACK extensions.
Enables the BackPACK extensions passed as arguments in the backward calls inside the current with block.
- Parameters
exts – Extensions to activate in the backward pass.
extension_hook – Function called on each module after all BackPACK extensions have run. Takes a torch.nn.Module and returns None. Default: None (no operation will be performed).
debug – Print debug messages during the backward pass. Default: False.
retain_graph – Determines whether BackPACK IO should be kept for additional backward passes. Should have the same value as the argument retain_graph in backward(). Default: False.
Note
extension_hook can be used to reduce memory overhead if the goal is to compute transformations of BackPACK quantities. Information can be compacted during a backward pass and obsolete tensors freed manually (del).
- Raises
ValueError – if extensions are not valid
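To illustrate the note above, here is a hedged sketch of an extension_hook that compacts per-sample gradients (from the BatchGrad extension) into per-sample gradient norms and frees the large grad_batch tensors right away. The hook and the grad_batch_norms attribute name are made up for this example; grad_batch is the quantity BatchGrad stores:

import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

def compact_grad_batch(module: torch.nn.Module) -> None:
    # called on each module after BackPACK's extensions have run on it
    for param in module.parameters(recurse=False):
        if hasattr(param, "grad_batch"):
            # keep only the per-sample gradient norms ...
            param.grad_batch_norms = param.grad_batch.flatten(start_dim=1).norm(dim=1)
            # ... and free the full per-sample gradients to save memory
            del param.grad_batch

model = extend(torch.nn.Linear(784, 10))
lossfunc = extend(torch.nn.CrossEntropyLoss())
X, y = torch.randn(32, 784), torch.randint(0, 10, (32,))

loss = lossfunc(model(X), y)
with backpack(BatchGrad(), extension_hook=compact_grad_batch):
    loss.backward()

for param in model.parameters():
    print(param.grad_batch_norms.shape)  # torch.Size([32])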
- backpack.disable()¶
Entirely disable BackPACK, including storage of input and output.
To compute the additional quantities, BackPACK needs to know the input and output of the modules in the computation graph. It saves those by default. disable tells BackPACK to not save this information during the forward pass.
This can be useful if you only want a gradient with PyTorch on a module that is extended with BackPACK and need to avoid memory overhead. If you do not need any gradient, use the torch.no_grad context instead.
This context is not the exact opposite of the backpack context. The backpack context enables specific extensions during a backward pass. This context disables storing input/output information during a forward pass.
Note
with backpack(...) in a with disable() context will fail even if the forward pass is carried out in with backpack(...).
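As a minimal sketch of when disable is useful (the setup below is assumed, not taken from the docs above): a model that has been extended with BackPACK is sometimes used for plain PyTorch gradients only, and wrapping that forward pass in disable() skips storing the input/output information:

import torch
from backpack import backpack, disable, extend
from backpack.extensions import BatchGrad

model = extend(torch.nn.Linear(784, 10))
lossfunc = extend(torch.nn.CrossEntropyLoss())
X, y = torch.randn(32, 784), torch.randint(0, 10, (32,))

# plain gradients only: no need to store BackPACK's input/output information
with disable():
    loss = lossfunc(model(X), y)
    loss.backward()

# regular forward pass (outside disable()), so extensions work again
loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()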