Residual networks
There are three different approaches to using BackPACK with ResNets.

1. Custom ResNet: (Only works for first-order extensions) Write your own model by defining its forward pass. Trainable parameters must be in modules known to BackPACK (e.g. torch.nn.Conv2d, torch.nn.Linear).
2. Custom ResNet with BackPACK custom modules: (Works for first- and second-order extensions) Build your ResNet with custom modules provided by BackPACK without overwriting the forward pass. This approach is useful if you want to understand how BackPACK handles ResNets, or if you think building a container module that implicitly defines the forward pass is more elegant than coding up a forward pass.
3. Any ResNet with BackPACK’s converter: (Works for first- and second-order extensions) Convert your model into a BackPACK-compatible architecture.
Note
ResNets are still an experimental feature. Always double-check your results, as done in this example! Open an issue if you encounter a bug to help us improve the support.
Not all extensions support ResNets (yet). Please create a feature request in the repository if the extension you need is not supported.
Let’s get the imports out of the way.
from torch import (
    allclose,
    cat,
    cuda,
    device,
    int32,
    linspace,
    manual_seed,
    rand,
    rand_like,
)
from torch.nn import (
    Conv2d,
    CrossEntropyLoss,
    Flatten,
    Identity,
    Linear,
    Module,
    MSELoss,
    ReLU,
    Sequential,
)
from torch.nn.functional import cross_entropy, relu
from torchvision.models import resnet18
from backpack import backpack, extend
from backpack.custom_module.branching import Parallel, SumModule
from backpack.custom_module.graph_utils import BackpackTracer
from backpack.extensions import BatchGrad, DiagGGNExact
from backpack.utils.examples import autograd_diag_ggn_exact, load_one_batch_mnist
manual_seed(0)
DEVICE = device("cuda:0" if cuda.is_available() else "cpu")
x, y = load_one_batch_mnist(batch_size=32)
x, y = x.to(DEVICE), y.to(DEVICE)
Custom ResNet
We can build a ResNet by extending torch.nn.Module. As long as the layers with parameters (torch.nn.Conv2d and torch.nn.Linear) are nn modules, BackPACK can extend them, and this is all that is needed for first-order extensions. We can rewrite the forward to implement the residual connection, and extend() the resulting model.
Note
Using in-place operations is not compatible with PyTorch’s torch.nn.Module.register_full_backward_hook(). Therefore, always use x = x + residual instead of x += residual.
class MyFirstResNet(Module):
    def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
        """Instantiate submodules that are used in the forward pass."""
        super().__init__()

        self.conv1 = Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
        self.conv2 = Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
        self.linear1 = Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)

        if C_in == C_hid:
            self.shortcut = Identity()
        else:
            self.shortcut = Conv2d(C_in, C_hid, kernel_size=1, stride=1)

    def forward(self, x):
        """Manual implementation of the forward pass."""
        residual = self.shortcut(x)
        x = self.conv2(relu(self.conv1(x)))
        x = x + residual  # don't use: x += residual
        x = x.flatten(start_dim=1)
        x = self.linear1(x)
        return x
model = extend(MyFirstResNet()).to(DEVICE)
The loss function does not need to be extended in this case, since it has no parameters of its own and BackPACK does not need to know about it for first-order extensions. This also means you can use any custom loss function (a short sketch follows at the end of this subsection).
Using BatchGrad in a with backpack(...) block, we can access the individual gradients for each sample.
loss = cross_entropy(model(x), y, reduction="sum")
with backpack(BatchGrad()):
    loss.backward()

for name, parameter in model.named_parameters():
    print(f"{name:>20}'s grad_batch shape: {parameter.grad_batch.shape}")
/home/docs/checkouts/readthedocs.org/user_builds/backpack/envs/master/lib/python3.9/site-packages/backpack/extensions/backprop_extension.py:107: UserWarning: Extension saving to grad_batch does not have an extension for Module <class '__main__.MyFirstResNet'> although the module has parameters
warnings.warn(
conv1.weight's grad_batch shape: torch.Size([32, 5, 1, 3, 3])
conv1.bias's grad_batch shape: torch.Size([32, 5])
conv2.weight's grad_batch shape: torch.Size([32, 5, 5, 3, 3])
conv2.bias's grad_batch shape: torch.Size([32, 5])
linear1.weight's grad_batch shape: torch.Size([32, 10, 3920])
linear1.bias's grad_batch shape: torch.Size([32, 10])
shortcut.weight's grad_batch shape: torch.Size([32, 5, 1, 1, 1])
shortcut.bias's grad_batch shape: torch.Size([32, 5])
To check that everything works, let’s compute one individual gradient with PyTorch (using a single sample in a forward and backward pass) and compare it with the one computed by BackPACK.
sample_to_check = 1
x_to_check = x[[sample_to_check]]
y_to_check = y[[sample_to_check]]
model.zero_grad()
loss = cross_entropy(model(x_to_check), y_to_check)
loss.backward()
print("Do the individual gradients match?")
for name, parameter in model.named_parameters():
match = allclose(parameter.grad_batch[sample_to_check], parameter.grad, atol=1e-6)
print(f"{name:>20}: {match}")
if not match:
raise AssertionError("Individual gradients don't match!")
Do the individual gradients match?
conv1.weight: True
conv1.bias: True
conv2.weight: True
conv2.bias: True
linear1.weight: True
linear1.bias: True
shortcut.weight: True
shortcut.bias: True
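As stated above, the loss itself does not have to be known to BackPACK for first-order extensions. The following is a minimal sketch (not part of the original example) with a hypothetical hand-written loss; individual gradients remain available because the loss is still a sum over per-sample terms.

def my_custom_loss(logits, targets):
    """Hypothetical hand-written loss: cross-entropy plus a small logit penalty."""
    return cross_entropy(logits, targets, reduction="sum") + 1e-3 * (logits**2).sum()

model.zero_grad()
loss = my_custom_loss(model(x), y)

with backpack(BatchGrad()):
    loss.backward()

# same per-sample layout as before, e.g. for the final linear layer
print(model.linear1.weight.grad_batch.shape)  # expected: torch.Size([32, 10, 3920])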
Custom ResNet with BackPACK custom modules
Second-order extensions only work if every node in the computation graph is an nn module that can be extended by BackPACK. The above ResNet class MyFirstResNet does not satisfy these conditions, because it implements the skip connection via torch.add() while overwriting the forward() method.
To build ResNets without overwriting the forward pass, BackPACK offers custom modules:

- Parallel is similar to torch.nn.Sequential, but implements a container for a parallel sequence of modules (followed by an aggregation module), rather than a sequential one.
- SumModule is the module that takes the role of torch.add() in the previous example. It sums up multiple inputs. We will use it to merge the skip connection (see the short sketch after this list).
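Before building the model, here is a tiny sketch (not part of the original example) of what Parallel computes: every branch receives the same input, and the merge module combines the branch outputs. With SumModule, the result equals summing the branch outputs by hand.

# A parallel block with two branches: an identity shortcut and a ReLU.
demo_block = Parallel(Identity(), ReLU(), merge_module=SumModule())
demo_input = rand(3, 4)

# The block output should match the manually merged branch outputs.
print(allclose(demo_block(demo_input), demo_input + relu(demo_input)))  # expected: True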
With the above modules, we can build a simple ResNet as a container that implicitly defines the forward pass:
C_in = 1
C_hid = 2
input_dim = (28, 28)
output_dim = 10
model = Sequential(
    Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1),
    ReLU(),
    Parallel(  # skip connection with ReLU-activated convolution
        Identity(),
        Sequential(
            Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1),
            ReLU(),
        ),
        merge_module=SumModule(),
    ),
    Flatten(),
    Linear(input_dim[0] * input_dim[1] * C_hid, output_dim),
)
model = extend(model.to(DEVICE))
loss_function = extend(CrossEntropyLoss(reduction="mean")).to(DEVICE)
This ResNet supports BackPACK’s second-order extensions:
loss = loss_function(model(x), y)
with backpack(DiagGGNExact()):
    loss.backward()

for name, parameter in model.named_parameters():
    print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")
diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()])
0.weight's diag_ggn_exact: torch.Size([2, 1, 3, 3])
0.bias's diag_ggn_exact: torch.Size([2])
2.branch.1.0.weight's diag_ggn_exact: torch.Size([2, 2, 3, 3])
2.branch.1.0.bias's diag_ggn_exact: torch.Size([2])
4.weight's diag_ggn_exact: torch.Size([10, 1568])
4.bias's diag_ggn_exact: torch.Size([10])
Comparison with torch.autograd:
Note
Computing the full GGN diagonal with PyTorch’s built-in automatic differentiation can be slow, depending on the number of parameters. To reduce run time, we only compare some elements of the diagonal.
num_params = sum(p.numel() for p in model.parameters())
num_to_compare = 10
idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)
diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
    x, y, model, loss_function, idx=idx_to_compare
)
print("Do the exact GGN diagonals match?")
for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
    match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6)
    print(f"Diagonal entry {idx:>6}: {match}")

    if not match:
        raise AssertionError("Exact GGN diagonals don't match!")
Do the exact GGN diagonals match?
Diagonal entry 0: True
Diagonal entry 1749: True
Diagonal entry 3499: True
Diagonal entry 5249: True
Diagonal entry 6998: True
Diagonal entry 8748: True
Diagonal entry 10498: True
Diagonal entry 12247: True
Diagonal entry 13997: True
Diagonal entry 15747: True
Any ResNet with BackPACK’s converter
If you are not building a ResNet through custom modules but for instance want to use a prominent ResNet from torchvision.models, BackPACK offers a converter. It analyzes the model and tries to turn it into a compatible architecture. The result is a torch.fx.GraphModule that exclusively consists of modules. Here, we demo the converter on resnet18.
Note
resnet18 has to be in evaluation mode, because it contains batch normalization layers that are not supported in train mode by the second-order extension used in this example.
Let’s create the model, and convert it in the call to extend:
loss_function = extend(MSELoss().to(DEVICE))
model = resnet18(num_classes=5).to(DEVICE).eval()
# use BackPACK's converter to extend the model (turned off by default)
model = extend(model, use_converter=True)
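As a quick, hedged sanity check (not part of the original example), we can confirm the claim above that the converted model is a torch.fx.GraphModule:

from torch.fx import GraphModule

# The converter rewrites the network into a graph module that only
# contains module calls, which BackPACK's hooks can handle.
print(isinstance(model, GraphModule))  # expected: True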
To get an understanding of what happened, we can inspect the model’s graph with the following helper function:
def print_table(module: Module) -> None:
    """Prints a table of the module.

    Args:
        module: module to analyze
    """
    graph = BackpackTracer().trace(module)
    graph.print_tabular()


print_table(model)
opcode name target args kwargs
----------- --------------------- --------------------- ------------------------------------- --------
placeholder x x () {}
call_module conv1 conv1 (x,) {}
call_module bn1 bn1 (conv1,) {}
call_module relu relu (bn1,) {}
call_module maxpool maxpool (relu,) {}
call_module layer1_0_conv1 layer1.0.conv1 (maxpool,) {}
call_module layer1_0_bn1 layer1.0.bn1 (layer1_0_conv1,) {}
call_module layer1_0_relu0 layer1.0.relu0 (layer1_0_bn1,) {}
call_module layer1_0_conv2 layer1.0.conv2 (layer1_0_relu0,) {}
call_module layer1_0_bn2 layer1.0.bn2 (layer1_0_conv2,) {}
call_module sum_module0 sum_module0 (layer1_0_bn2, maxpool) {}
call_module layer1_0_relu1 layer1.0.relu1 (sum_module0,) {}
call_module layer1_1_conv1 layer1.1.conv1 (layer1_0_relu1,) {}
call_module layer1_1_bn1 layer1.1.bn1 (layer1_1_conv1,) {}
call_module layer1_1_relu0 layer1.1.relu0 (layer1_1_bn1,) {}
call_module layer1_1_conv2 layer1.1.conv2 (layer1_1_relu0,) {}
call_module layer1_1_bn2 layer1.1.bn2 (layer1_1_conv2,) {}
call_module sum_module1 sum_module1 (layer1_1_bn2, layer1_0_relu1) {}
call_module layer1_1_relu1 layer1.1.relu1 (sum_module1,) {}
call_module layer2_0_conv1 layer2.0.conv1 (layer1_1_relu1,) {}
call_module layer2_0_bn1 layer2.0.bn1 (layer2_0_conv1,) {}
call_module layer2_0_relu0 layer2.0.relu0 (layer2_0_bn1,) {}
call_module layer2_0_conv2 layer2.0.conv2 (layer2_0_relu0,) {}
call_module layer2_0_bn2 layer2.0.bn2 (layer2_0_conv2,) {}
call_module layer2_0_downsample_0 layer2.0.downsample.0 (layer1_1_relu1,) {}
call_module layer2_0_downsample_1 layer2.0.downsample.1 (layer2_0_downsample_0,) {}
call_module sum_module2 sum_module2 (layer2_0_bn2, layer2_0_downsample_1) {}
call_module layer2_0_relu1 layer2.0.relu1 (sum_module2,) {}
call_module layer2_1_conv1 layer2.1.conv1 (layer2_0_relu1,) {}
call_module layer2_1_bn1 layer2.1.bn1 (layer2_1_conv1,) {}
call_module layer2_1_relu0 layer2.1.relu0 (layer2_1_bn1,) {}
call_module layer2_1_conv2 layer2.1.conv2 (layer2_1_relu0,) {}
call_module layer2_1_bn2 layer2.1.bn2 (layer2_1_conv2,) {}
call_module sum_module3 sum_module3 (layer2_1_bn2, layer2_0_relu1) {}
call_module layer2_1_relu1 layer2.1.relu1 (sum_module3,) {}
call_module layer3_0_conv1 layer3.0.conv1 (layer2_1_relu1,) {}
call_module layer3_0_bn1 layer3.0.bn1 (layer3_0_conv1,) {}
call_module layer3_0_relu0 layer3.0.relu0 (layer3_0_bn1,) {}
call_module layer3_0_conv2 layer3.0.conv2 (layer3_0_relu0,) {}
call_module layer3_0_bn2 layer3.0.bn2 (layer3_0_conv2,) {}
call_module layer3_0_downsample_0 layer3.0.downsample.0 (layer2_1_relu1,) {}
call_module layer3_0_downsample_1 layer3.0.downsample.1 (layer3_0_downsample_0,) {}
call_module sum_module4 sum_module4 (layer3_0_bn2, layer3_0_downsample_1) {}
call_module layer3_0_relu1 layer3.0.relu1 (sum_module4,) {}
call_module layer3_1_conv1 layer3.1.conv1 (layer3_0_relu1,) {}
call_module layer3_1_bn1 layer3.1.bn1 (layer3_1_conv1,) {}
call_module layer3_1_relu0 layer3.1.relu0 (layer3_1_bn1,) {}
call_module layer3_1_conv2 layer3.1.conv2 (layer3_1_relu0,) {}
call_module layer3_1_bn2 layer3.1.bn2 (layer3_1_conv2,) {}
call_module sum_module5 sum_module5 (layer3_1_bn2, layer3_0_relu1) {}
call_module layer3_1_relu1 layer3.1.relu1 (sum_module5,) {}
call_module layer4_0_conv1 layer4.0.conv1 (layer3_1_relu1,) {}
call_module layer4_0_bn1 layer4.0.bn1 (layer4_0_conv1,) {}
call_module layer4_0_relu0 layer4.0.relu0 (layer4_0_bn1,) {}
call_module layer4_0_conv2 layer4.0.conv2 (layer4_0_relu0,) {}
call_module layer4_0_bn2 layer4.0.bn2 (layer4_0_conv2,) {}
call_module layer4_0_downsample_0 layer4.0.downsample.0 (layer3_1_relu1,) {}
call_module layer4_0_downsample_1 layer4.0.downsample.1 (layer4_0_downsample_0,) {}
call_module sum_module6 sum_module6 (layer4_0_bn2, layer4_0_downsample_1) {}
call_module layer4_0_relu1 layer4.0.relu1 (sum_module6,) {}
call_module layer4_1_conv1 layer4.1.conv1 (layer4_0_relu1,) {}
call_module layer4_1_bn1 layer4.1.bn1 (layer4_1_conv1,) {}
call_module layer4_1_relu0 layer4.1.relu0 (layer4_1_bn1,) {}
call_module layer4_1_conv2 layer4.1.conv2 (layer4_1_relu0,) {}
call_module layer4_1_bn2 layer4.1.bn2 (layer4_1_conv2,) {}
call_module sum_module7 sum_module7 (layer4_1_bn2, layer4_0_relu1) {}
call_module layer4_1_relu1 layer4.1.relu1 (sum_module7,) {}
call_module avgpool avgpool (layer4_1_relu1,) {}
call_module flatten0 flatten0 (avgpool,) {}
call_module fc fc (flatten0,) {}
output output output (fc,) {}
Admittedly, the converted resnet18’s graph is quite large. Note however that it fully consists of modules (indicated by call_module in the first table column) such that BackPACK’s hooks can successfully backpropagate additional information for its second-order extensions (first-order extensions work, too).
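As a short, hedged sketch (not part of the original example), first-order extensions such as BatchGrad indeed run on the converted model as well; the input and target below are random and only serve to illustrate the call.

x_demo = rand(4, 3, 7, 7, device=DEVICE)
y_demo = rand(4, 5, device=DEVICE)  # the resnet18 above was created with num_classes=5

with backpack(BatchGrad()):
    loss_function(model(x_demo), y_demo).backward()

# individual gradients for the final classifier parameters
for name, parameter in model.named_parameters():
    if name in ("fc.weight", "fc.bias"):
        print(f"{name}'s grad_batch shape: {parameter.grad_batch.shape}")

model.zero_grad()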
Let’s verify that second-order extensions are working:
x = rand(4, 3, 7, 7, device=DEVICE) # (128, 3, 224, 224)
output = model(x)
y = rand_like(output)
loss = loss_function(output, y)
with backpack(DiagGGNExact()):
    loss.backward()

for name, parameter in model.named_parameters():
    print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")
diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()])
conv1.weight's diag_ggn_exact: torch.Size([64, 3, 7, 7])
bn1.weight's diag_ggn_exact: torch.Size([64])
bn1.bias's diag_ggn_exact: torch.Size([64])
layer1.0.conv1.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight's diag_ggn_exact: torch.Size([64])
layer1.0.bn1.bias's diag_ggn_exact: torch.Size([64])
layer1.0.conv2.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight's diag_ggn_exact: torch.Size([64])
layer1.0.bn2.bias's diag_ggn_exact: torch.Size([64])
layer1.1.conv1.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight's diag_ggn_exact: torch.Size([64])
layer1.1.bn1.bias's diag_ggn_exact: torch.Size([64])
layer1.1.conv2.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight's diag_ggn_exact: torch.Size([64])
layer1.1.bn2.bias's diag_ggn_exact: torch.Size([64])
layer2.0.conv1.weight's diag_ggn_exact: torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight's diag_ggn_exact: torch.Size([128])
layer2.0.bn1.bias's diag_ggn_exact: torch.Size([128])
layer2.0.conv2.weight's diag_ggn_exact: torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight's diag_ggn_exact: torch.Size([128])
layer2.0.bn2.bias's diag_ggn_exact: torch.Size([128])
layer2.0.downsample.0.weight's diag_ggn_exact: torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight's diag_ggn_exact: torch.Size([128])
layer2.0.downsample.1.bias's diag_ggn_exact: torch.Size([128])
layer2.1.conv1.weight's diag_ggn_exact: torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight's diag_ggn_exact: torch.Size([128])
layer2.1.bn1.bias's diag_ggn_exact: torch.Size([128])
layer2.1.conv2.weight's diag_ggn_exact: torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight's diag_ggn_exact: torch.Size([128])
layer2.1.bn2.bias's diag_ggn_exact: torch.Size([128])
layer3.0.conv1.weight's diag_ggn_exact: torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight's diag_ggn_exact: torch.Size([256])
layer3.0.bn1.bias's diag_ggn_exact: torch.Size([256])
layer3.0.conv2.weight's diag_ggn_exact: torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight's diag_ggn_exact: torch.Size([256])
layer3.0.bn2.bias's diag_ggn_exact: torch.Size([256])
layer3.0.downsample.0.weight's diag_ggn_exact: torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight's diag_ggn_exact: torch.Size([256])
layer3.0.downsample.1.bias's diag_ggn_exact: torch.Size([256])
layer3.1.conv1.weight's diag_ggn_exact: torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight's diag_ggn_exact: torch.Size([256])
layer3.1.bn1.bias's diag_ggn_exact: torch.Size([256])
layer3.1.conv2.weight's diag_ggn_exact: torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight's diag_ggn_exact: torch.Size([256])
layer3.1.bn2.bias's diag_ggn_exact: torch.Size([256])
layer4.0.conv1.weight's diag_ggn_exact: torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight's diag_ggn_exact: torch.Size([512])
layer4.0.bn1.bias's diag_ggn_exact: torch.Size([512])
layer4.0.conv2.weight's diag_ggn_exact: torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight's diag_ggn_exact: torch.Size([512])
layer4.0.bn2.bias's diag_ggn_exact: torch.Size([512])
layer4.0.downsample.0.weight's diag_ggn_exact: torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight's diag_ggn_exact: torch.Size([512])
layer4.0.downsample.1.bias's diag_ggn_exact: torch.Size([512])
layer4.1.conv1.weight's diag_ggn_exact: torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight's diag_ggn_exact: torch.Size([512])
layer4.1.bn1.bias's diag_ggn_exact: torch.Size([512])
layer4.1.conv2.weight's diag_ggn_exact: torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight's diag_ggn_exact: torch.Size([512])
layer4.1.bn2.bias's diag_ggn_exact: torch.Size([512])
fc.weight's diag_ggn_exact: torch.Size([5, 512])
fc.bias's diag_ggn_exact: torch.Size([5])
Comparison with torch.autograd:
Note
Computing the full GGN diagonal with PyTorch’s built-in automatic differentiation can be slow, depending on the number of parameters. To reduce run time, we only compare some elements of the diagonal.
num_params = sum(p.numel() for p in model.parameters())
num_to_compare = 10
idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)
diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
    x, y, model, loss_function, idx=idx_to_compare
)
print("Do the exact GGN diagonals match?")
for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
    match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6)
    print(f"Diagonal entry {idx:>8}: {match}")

    if not match:
        raise AssertionError("Exact GGN diagonals don't match!")
Do the exact GGN diagonals match?
Diagonal entry 0: True
Diagonal entry 1242119: True
Diagonal entry 2484239: True
Diagonal entry 3726358: True
Diagonal entry 4968478: True
Diagonal entry 6210597: True
Diagonal entry 7452717: True
Diagonal entry 8694836: True
Diagonal entry 9936956: True
Diagonal entry 11179076: True