Supported models
BackPACK expects models to be sequences of PyTorch NN modules. For example,
model = torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10)
)
If you override any forward() function (for example in ResNets and RNNs), the additional backward pass to compute second-order quantities will break. You can ask BackPACK to inspect the graph and try converting it into a compatible structure by setting use_converter=True in extend.
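A minimal sketch of both cases, assuming backpack.extend is importable from the backpack package (my_resnet is a placeholder for a model with a custom forward()):

from backpack import extend

# The sequential model above can be extended directly.
model = extend(model)

# A model with a custom forward(), e.g. a ResNet, needs the converter so that
# BackPACK can try to rewrite its graph into a compatible structure.
# my_resnet = extend(my_resnet, use_converter=True)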
This page lists the layers currently supported by BackPACK.
For first-order extensions
BackPACK can extract more information about the gradient with respect to the parameters of the following layers:
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d
torch.nn.BatchNorm1d (evaluation mode), torch.nn.BatchNorm2d (evaluation mode), torch.nn.BatchNorm3d (evaluation mode)
Some layers (like torch.nn.BatchNormNd in training mode) mix samples and lead to ill-defined first-order quantities.
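For illustration, a minimal sketch of a first-order extension applied to such layers, assuming BackPACK's BatchGrad extension, which stores per-sample gradients in a grad_batch attribute:

import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

# Toy model built from layers listed above; BatchNorm is kept in evaluation mode.
model = extend(torch.nn.Sequential(
    torch.nn.ConvTranspose2d(3, 2, kernel_size=3),
    torch.nn.BatchNorm2d(2),
    torch.nn.Flatten(),
    torch.nn.Linear(200, 5),
)).eval()
lossfunc = extend(torch.nn.CrossEntropyLoss())

X, y = torch.rand(4, 3, 8, 8), torch.randint(0, 5, (4,))
with backpack(BatchGrad()):
    lossfunc(model(X), y).backward()

for param in model.parameters():
    print(param.grad_batch.shape)  # one gradient per sample: [4, *param.shape]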
For second-order extensions
BackPACK needs to know how to backpropagate additional information for second-order quantities. This is implemented for:
Parametrized layers
Loss functions
Layers without parameters
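For illustration, a minimal sketch of a second-order extension on a compatible model, assuming BackPACK's DiagGGNExact extension, which stores the diagonal of the generalized Gauss-Newton in a diag_ggn_exact attribute; note that both the model and the loss function are extended:

import torch
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

model = extend(torch.nn.Sequential(
    torch.nn.Linear(784, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
))
lossfunc = extend(torch.nn.CrossEntropyLoss())

X, y = torch.rand(32, 784), torch.randint(0, 10, (32,))
with backpack(DiagGGNExact()):
    lossfunc(model(X), y).backward()

for param in model.parameters():
    print(param.diag_ggn_exact.shape)  # same shape as param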