Residual networks

There are three different approaches to using BackPACK with ResNets.

  1. Custom ResNet: (Only works for first-order extensions) Write your own model by defining its forward pass. Trainable parameters must be in modules known to BackPACK (e.g. torch.nn.Conv2d, torch.nn.Linear).

  2. Custom ResNet with BackPACK custom modules: (Works for first- and second-order extensions) Build your ResNet with custom modules provided by BackPACK without overriding the forward pass. This approach is useful if you want to understand how BackPACK handles ResNets, or if you find a container module that implicitly defines the forward pass more elegant than writing the forward pass by hand.

  3. Any ResNet with BackPACK’s converter: (Works for first- and second-order extensions) Convert your model into a BackPACK-compatible architecture.

Note

ResNets are still an experimental feature. Always double-check your results, as done in this example! Open an issue if you encounter a bug to help us improve the support.

Not all extensions support ResNets (yet). Please create a feature request in the repository if the extension you need is not supported.

Let’s get the imports out of the way.

from torch import (
    allclose,
    cat,
    cuda,
    device,
    int32,
    linspace,
    manual_seed,
    rand,
    rand_like,
)
from torch.nn import (
    Conv2d,
    CrossEntropyLoss,
    Flatten,
    Identity,
    Linear,
    Module,
    MSELoss,
    ReLU,
    Sequential,
)
from torch.nn.functional import cross_entropy, relu
from torchvision.models import resnet18

from backpack import backpack, extend
from backpack.custom_module.branching import Parallel, SumModule
from backpack.custom_module.graph_utils import BackpackTracer
from backpack.extensions import BatchGrad, DiagGGNExact
from backpack.utils.examples import autograd_diag_ggn_exact, load_one_batch_mnist

manual_seed(0)
DEVICE = device("cuda:0" if cuda.is_available() else "cpu")
x, y = load_one_batch_mnist(batch_size=32)
x, y = x.to(DEVICE), y.to(DEVICE)

Custom ResNet

We can build a ResNet by extending torch.nn.Module. As long as the layers with parameters (torch.nn.Conv2d and torch.nn.Linear) are nn modules, BackPACK can extend them, and this is all that is needed for first-order extensions. We implement the residual connection by overriding forward(), and then call extend() on the resulting model.

Note

Using in-place operations is not compatible with PyTorch’s torch.nn.Module.register_full_backward_hook(). Therefore, always use x = x + residual instead of x += residual.
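The restriction can be reproduced in plain PyTorch, without BackPACK. The sketch below registers a full backward hook on a Conv2d and then adds a residual to the hooked output, once out-of-place and once in-place. The helper `try_residual` is hypothetical. PyTorch's documentation for register_full_backward_hook states that modifying inputs or outputs in-place raises an error; the exact error message may differ between PyTorch versions.

```python
import torch
from torch.nn import Conv2d


def try_residual(in_place: bool) -> bool:
    """Return True if forward and backward succeed and the hook fires, False on RuntimeError."""
    torch.manual_seed(0)
    conv = Conv2d(2, 2, kernel_size=3, padding=1)
    calls = []
    # The hook is called once gradients w.r.t. the module's inputs have been computed
    conv.register_full_backward_hook(lambda module, g_in, g_out: calls.append(module))

    x = torch.rand(1, 2, 4, 4, requires_grad=True)
    residual = x
    try:
        out = conv(x)
        if in_place:
            out += residual  # modifies the hooked output in place
        else:
            out = out + residual  # creates a new tensor, the hooked output stays intact
        out.sum().backward()
    except RuntimeError as error:
        print("RuntimeError:", str(error).splitlines()[0])
        return False
    return len(calls) == 1


print("out-of-place works:", try_residual(in_place=False))
print("in-place works:", try_residual(in_place=True))
```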

class MyFirstResNet(Module):
    def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
        """Instantiate submodules that are used in the forward pass."""
        super().__init__()

        self.conv1 = Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
        self.conv2 = Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
        self.linear1 = Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)
        if C_in == C_hid:
            self.shortcut = Identity()
        else:
            self.shortcut = Conv2d(C_in, C_hid, kernel_size=1, stride=1)

    def forward(self, x):
        """Manual implementation of the forward pass."""
        residual = self.shortcut(x)
        x = self.conv2(relu(self.conv1(x)))
        x = x + residual  # don't use: x += residual
        x = x.flatten(start_dim=1)
        x = self.linear1(x)
        return x


model = extend(MyFirstResNet()).to(DEVICE)

The loss does not need to be extended in this case either, as it does not have model parameters and BackPACK does not need to know about it for first-order extensions. This also means you can use any custom loss function.

Using BatchGrad in a with backpack(...) block, we can access the individual gradients for each sample.
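Conceptually, BatchGrad stores for every parameter a tensor of shape [N, *parameter.shape] whose n-th slice is the gradient of the n-th sample's loss term. The plain-PyTorch sketch below computes the same quantity the slow way, with one backward pass per sample, for a small Linear layer (all sizes are arbitrary choices for illustration). Because the loss uses reduction="sum", the individual gradients add up to the ordinary batch gradient.

```python
import torch
from torch.nn import Linear
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
N, D, C = 8, 6, 3  # batch size, input features, classes
layer = Linear(D, C)
x = torch.randn(N, D)
y = torch.randint(0, C, (N,))

# Reference individual gradients: one forward/backward pass per sample
grad_batch = {name: [] for name, _ in layer.named_parameters()}
for n in range(N):
    layer.zero_grad()
    cross_entropy(layer(x[[n]]), y[[n]], reduction="sum").backward()
    for name, p in layer.named_parameters():
        grad_batch[name].append(p.grad.clone())
grad_batch = {name: torch.stack(grads) for name, grads in grad_batch.items()}

# Ordinary gradient of the summed loss over the whole batch
layer.zero_grad()
cross_entropy(layer(x), y, reduction="sum").backward()

for name, p in layer.named_parameters():
    # Leading dimension is the batch dimension, the rest matches the parameter
    print(f"{name}: {tuple(grad_batch[name].shape)}")
    # Summing the individual gradients recovers the batch gradient
    print("  sums to batch gradient:", torch.allclose(grad_batch[name].sum(0), p.grad, atol=1e-5))
```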

loss = cross_entropy(model(x), y, reduction="sum")

with backpack(BatchGrad()):
    loss.backward()

for name, parameter in model.named_parameters():
    print(f"{name:>20}'s grad_batch shape: {parameter.grad_batch.shape}")

Out:

/home/docs/checkouts/readthedocs.org/user_builds/backpack/envs/master/lib/python3.7/site-packages/backpack/extensions/backprop_extension.py:107: UserWarning: Extension saving to grad_batch does not have an extension for Module <class '__main__.MyFirstResNet'> although the module has parameters
  f"Extension saving to {self.savefield} does not have an "
        conv1.weight's grad_batch shape: torch.Size([32, 5, 1, 3, 3])
          conv1.bias's grad_batch shape: torch.Size([32, 5])
        conv2.weight's grad_batch shape: torch.Size([32, 5, 5, 3, 3])
          conv2.bias's grad_batch shape: torch.Size([32, 5])
      linear1.weight's grad_batch shape: torch.Size([32, 10, 3920])
        linear1.bias's grad_batch shape: torch.Size([32, 10])
     shortcut.weight's grad_batch shape: torch.Size([32, 5, 1, 1, 1])
       shortcut.bias's grad_batch shape: torch.Size([32, 5])

To check that everything works, let’s compute one individual gradient with PyTorch (using a single sample in a forward and backward pass) and compare it with the one computed by BackPACK.

sample_to_check = 1
x_to_check = x[[sample_to_check]]
y_to_check = y[[sample_to_check]]

model.zero_grad()
loss = cross_entropy(model(x_to_check), y_to_check)
loss.backward()

print("Do the individual gradients match?")
for name, parameter in model.named_parameters():
    match = allclose(parameter.grad_batch[sample_to_check], parameter.grad, atol=1e-6)
    print(f"{name:>20}: {match}")
    if not match:
        raise AssertionError("Individual gradients don't match!")

Out:

Do the individual gradients match?
        conv1.weight: True
          conv1.bias: True
        conv2.weight: True
          conv2.bias: True
      linear1.weight: True
        linear1.bias: True
     shortcut.weight: True
       shortcut.bias: True

Custom ResNet with BackPACK custom modules

Second-order extensions only work if every node in the computation graph is an nn module that BackPACK can extend. The ResNet class MyFirstResNet above does not satisfy this condition: it implements the skip connection as a tensor addition (x + residual, equivalent to torch.add()) inside an overridden forward() method, so the addition is not performed by a module.

To build ResNets without overriding the forward pass, BackPACK offers custom modules:

  1. Parallel is similar to torch.nn.Sequential, but instead of applying its modules one after another, it applies them in parallel to the same input and combines their outputs with an aggregation module (merge_module).

  2. SumModule sums up multiple inputs and takes the role of the addition x + residual (torch.add()) in the previous example. We will use it to merge the skip connection.

With the above modules, we can build a simple ResNet as a container that implicitly defines the forward pass:

C_in = 1
C_hid = 2
input_dim = (28, 28)
output_dim = 10

model = Sequential(
    Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1),
    ReLU(),
    Parallel(  # skip connection with ReLU-activated convolution
        Identity(),
        Sequential(
            Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1),
            ReLU(),
        ),
        merge_module=SumModule(),
    ),
    Flatten(),
    Linear(input_dim[0] * input_dim[1] * C_hid, output_dim),
)

model = extend(model.to(DEVICE))
loss_function = extend(CrossEntropyLoss(reduction="mean")).to(DEVICE)
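To see what the Parallel block in this model computes, its forward semantics can be mimicked in plain PyTorch. The sketch below is not BackPACK's implementation of Parallel or SumModule; the classes HypotheticalParallel and HypotheticalSum are illustrative stand-ins that feed the same input to every branch and sum the results, so an Identity branch plus a convolutional branch yields x + branch(x).

```python
import torch
from torch.nn import Conv2d, Identity, Module, ModuleList, ReLU, Sequential


class HypotheticalSum(Module):
    """Illustrative stand-in for a summing merge module (not BackPACK's SumModule)."""

    def forward(self, *inputs):
        result = inputs[0]
        for tensor in inputs[1:]:
            result = result + tensor  # out-of-place addition
        return result


class HypotheticalParallel(Module):
    """Illustrative stand-in: feed the same input to every branch, then merge."""

    def __init__(self, *branches: Module, merge_module: Module):
        super().__init__()
        self.branches = ModuleList(branches)
        self.merge_module = merge_module

    def forward(self, x):
        return self.merge_module(*[branch(x) for branch in self.branches])


torch.manual_seed(0)
branch = Sequential(Conv2d(2, 2, kernel_size=3, padding=1), ReLU())
block = HypotheticalParallel(Identity(), branch, merge_module=HypotheticalSum())

x = torch.rand(3, 2, 5, 5)
print(torch.allclose(block(x), x + branch(x)))  # True
```

Because the whole block is a module tree, every operation, including the summation, happens inside a module; this is the property the second-order extensions rely on.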

This ResNet supports BackPACK’s second-order extensions:

loss = loss_function(model(x), y)

with backpack(DiagGGNExact()):
    loss.backward()

for name, parameter in model.named_parameters():
    print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")

diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()])

Out:

0.weight's diag_ggn_exact: torch.Size([2, 1, 3, 3])
0.bias's diag_ggn_exact: torch.Size([2])
2.branch.1.0.weight's diag_ggn_exact: torch.Size([2, 2, 3, 3])
2.branch.1.0.bias's diag_ggn_exact: torch.Size([2])
4.weight's diag_ggn_exact: torch.Size([10, 1568])
4.bias's diag_ggn_exact: torch.Size([10])
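The flattened diagonal has one entry per parameter. As a sanity check that is independent of BackPACK, the parameter count of the model above can be reproduced from its three parameterized layers alone (Parallel, SumModule, ReLU and Flatten add no parameters). It also explains the largest diagonal index, 15747, in the comparison that follows.

```python
from torch.nn import Conv2d, Linear

C_in, C_hid, input_dim, output_dim = 1, 2, (28, 28), 10

# The only layers with parameters in the Sequential model above
layers = [
    Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1),  # 2*1*3*3 + 2 = 20
    Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1),  # 2*2*3*3 + 2 = 38
    Linear(input_dim[0] * input_dim[1] * C_hid, output_dim),  # 1568*10 + 10 = 15690
]
num_params = sum(p.numel() for layer in layers for p in layer.parameters())
print(num_params)  # 15748
```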

Comparison with torch.autograd:

Note

Computing the full GGN diagonal with PyTorch’s built-in automatic differentiation can be slow, depending on the number of parameters. To reduce run time, we only compare some elements of the diagonal.

num_params = sum(p.numel() for p in model.parameters())
num_to_compare = 10
idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)

diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
    x, y, model, loss_function, idx=idx_to_compare
)

print("Do the exact GGN diagonals match?")
for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
    match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6)
    print(f"Diagonal entry {idx:>6}: {match}")
    if not match:
        raise AssertionError("Exact GGN diagonals don't match!")

Out:

Do the exact GGN diagonals match?
Diagonal entry      0: True
Diagonal entry   1749: True
Diagonal entry   3499: True
Diagonal entry   5249: True
Diagonal entry   6998: True
Diagonal entry   8748: True
Diagonal entry  10498: True
Diagonal entry  12247: True
Diagonal entry  13997: True
Diagonal entry  15747: True
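For intuition about what is being compared, here is a self-contained sketch of an exact GGN diagonal computed with torch.autograd.functional only, using the standard definition G = Jᵀ (∂²L/∂f²) J, where f(θ) are the model outputs, L the loss, and J = ∂f/∂θ. The toy model is a single linear layer with hand-flattened parameters (an assumption made for brevity). Since its outputs are linear in θ, the GGN coincides with the Hessian, which gives an independent check. Building J and G explicitly needs memory quadratic in the number of parameters, which is why the comparison above only checks a few entries.

```python
import torch
from torch.autograd.functional import hessian, jacobian
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
N, D, C = 4, 3, 5  # batch size, input features, classes
X = torch.randn(N, D)
Y = torch.randint(0, C, (N,))
theta = torch.randn(C * D + C)  # flattened parameters [W.flatten(), b]


def model_output(theta):
    W, b = theta[: C * D].reshape(C, D), theta[C * D :]
    return X @ W.T + b  # logits, shape (N, C)


def loss_of_output(f_flat):
    return cross_entropy(f_flat.reshape(N, C), Y)  # mean reduction


J = jacobian(model_output, theta).reshape(N * C, -1)  # (N*C, P)
H_f = hessian(loss_of_output, model_output(theta).detach().flatten())  # (N*C, N*C)
ggn_diag = (J.T @ H_f @ J).diagonal()  # (P,)

# For a model that is linear in theta, the Hessian w.r.t. theta equals the GGN
hess_diag = hessian(lambda t: loss_of_output(model_output(t).flatten()), theta).diagonal()
print(torch.allclose(ggn_diag, hess_diag, atol=1e-5))  # True
```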

Any ResNet with BackPACK’s converter

If you do not want to build a ResNet from custom modules but, for instance, want to use a prominent ResNet from torchvision.models, BackPACK offers a converter. It analyzes the model and tries to turn it into a compatible architecture. The result is a torch.fx.GraphModule that consists exclusively of modules.

Here, we demo the converter on resnet18.

Note

resnet18 has to be in evaluation mode, because it contains batch normalization layers that are not supported in train mode by the second-order extension used in this example.
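One plausible reason for this restriction (our interpretation, not a statement from BackPACK) is that batch normalization couples samples in training mode. In evaluation mode, BatchNorm2d normalizes with stored running statistics, so each sample is transformed independently; in training mode it uses the statistics of the current mini-batch. The plain-PyTorch check below illustrates the difference; the helper batch_equals_per_sample is hypothetical.

```python
import torch
from torch.nn import BatchNorm2d

torch.manual_seed(0)
bn = BatchNorm2d(3)
x = torch.randn(4, 3, 5, 5)


def batch_equals_per_sample(module, x):
    """Check whether processing the batch equals processing each sample alone."""
    with torch.no_grad():
        batched = module(x)
        single = torch.cat([module(x[[n]]) for n in range(x.shape[0])])
    return torch.allclose(batched, single, atol=1e-6)


bn.eval()  # running statistics: a fixed per-channel affine map
print("eval mode: ", batch_equals_per_sample(bn, x))  # True
bn.train()  # statistics of the current batch couple all samples
print("train mode:", batch_equals_per_sample(bn, x))  # False
```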

Let’s create the model, and convert it in the call to extend:

loss_function = extend(MSELoss().to(DEVICE))
model = resnet18(num_classes=5).to(DEVICE).eval()

# use BackPACK's converter to extend the model (turned off by default)
model = extend(model, use_converter=True)

To understand what happened, we can inspect the model’s graph with the following helper function:

def print_table(module: Module) -> None:
    """Prints a table of the module.

    Args:
        module: module to analyze
    """
    graph = BackpackTracer().trace(module)
    graph.print_tabular()


print_table(model)

Out:

opcode       name                   target                 args                                   kwargs
-----------  ---------------------  ---------------------  -------------------------------------  --------
placeholder  x                      x                      ()                                     {}
call_module  conv1                  conv1                  (x,)                                   {}
call_module  bn1                    bn1                    (conv1,)                               {}
call_module  relu                   relu                   (bn1,)                                 {}
call_module  maxpool                maxpool                (relu,)                                {}
call_module  layer1_0_conv1         layer1.0.conv1         (maxpool,)                             {}
call_module  layer1_0_bn1           layer1.0.bn1           (layer1_0_conv1,)                      {}
call_module  layer1_0_relu0         layer1.0.relu0         (layer1_0_bn1,)                        {}
call_module  layer1_0_conv2         layer1.0.conv2         (layer1_0_relu0,)                      {}
call_module  layer1_0_bn2           layer1.0.bn2           (layer1_0_conv2,)                      {}
call_module  sum_module0            sum_module0            (layer1_0_bn2, maxpool)                {}
call_module  layer1_0_relu1         layer1.0.relu1         (sum_module0,)                         {}
call_module  layer1_1_conv1         layer1.1.conv1         (layer1_0_relu1,)                      {}
call_module  layer1_1_bn1           layer1.1.bn1           (layer1_1_conv1,)                      {}
call_module  layer1_1_relu0         layer1.1.relu0         (layer1_1_bn1,)                        {}
call_module  layer1_1_conv2         layer1.1.conv2         (layer1_1_relu0,)                      {}
call_module  layer1_1_bn2           layer1.1.bn2           (layer1_1_conv2,)                      {}
call_module  sum_module1            sum_module1            (layer1_1_bn2, layer1_0_relu1)         {}
call_module  layer1_1_relu1         layer1.1.relu1         (sum_module1,)                         {}
call_module  layer2_0_conv1         layer2.0.conv1         (layer1_1_relu1,)                      {}
call_module  layer2_0_bn1           layer2.0.bn1           (layer2_0_conv1,)                      {}
call_module  layer2_0_relu0         layer2.0.relu0         (layer2_0_bn1,)                        {}
call_module  layer2_0_conv2         layer2.0.conv2         (layer2_0_relu0,)                      {}
call_module  layer2_0_bn2           layer2.0.bn2           (layer2_0_conv2,)                      {}
call_module  layer2_0_downsample_0  layer2.0.downsample.0  (layer1_1_relu1,)                      {}
call_module  layer2_0_downsample_1  layer2.0.downsample.1  (layer2_0_downsample_0,)               {}
call_module  sum_module2            sum_module2            (layer2_0_bn2, layer2_0_downsample_1)  {}
call_module  layer2_0_relu1         layer2.0.relu1         (sum_module2,)                         {}
call_module  layer2_1_conv1         layer2.1.conv1         (layer2_0_relu1,)                      {}
call_module  layer2_1_bn1           layer2.1.bn1           (layer2_1_conv1,)                      {}
call_module  layer2_1_relu0         layer2.1.relu0         (layer2_1_bn1,)                        {}
call_module  layer2_1_conv2         layer2.1.conv2         (layer2_1_relu0,)                      {}
call_module  layer2_1_bn2           layer2.1.bn2           (layer2_1_conv2,)                      {}
call_module  sum_module3            sum_module3            (layer2_1_bn2, layer2_0_relu1)         {}
call_module  layer2_1_relu1         layer2.1.relu1         (sum_module3,)                         {}
call_module  layer3_0_conv1         layer3.0.conv1         (layer2_1_relu1,)                      {}
call_module  layer3_0_bn1           layer3.0.bn1           (layer3_0_conv1,)                      {}
call_module  layer3_0_relu0         layer3.0.relu0         (layer3_0_bn1,)                        {}
call_module  layer3_0_conv2         layer3.0.conv2         (layer3_0_relu0,)                      {}
call_module  layer3_0_bn2           layer3.0.bn2           (layer3_0_conv2,)                      {}
call_module  layer3_0_downsample_0  layer3.0.downsample.0  (layer2_1_relu1,)                      {}
call_module  layer3_0_downsample_1  layer3.0.downsample.1  (layer3_0_downsample_0,)               {}
call_module  sum_module4            sum_module4            (layer3_0_bn2, layer3_0_downsample_1)  {}
call_module  layer3_0_relu1         layer3.0.relu1         (sum_module4,)                         {}
call_module  layer3_1_conv1         layer3.1.conv1         (layer3_0_relu1,)                      {}
call_module  layer3_1_bn1           layer3.1.bn1           (layer3_1_conv1,)                      {}
call_module  layer3_1_relu0         layer3.1.relu0         (layer3_1_bn1,)                        {}
call_module  layer3_1_conv2         layer3.1.conv2         (layer3_1_relu0,)                      {}
call_module  layer3_1_bn2           layer3.1.bn2           (layer3_1_conv2,)                      {}
call_module  sum_module5            sum_module5            (layer3_1_bn2, layer3_0_relu1)         {}
call_module  layer3_1_relu1         layer3.1.relu1         (sum_module5,)                         {}
call_module  layer4_0_conv1         layer4.0.conv1         (layer3_1_relu1,)                      {}
call_module  layer4_0_bn1           layer4.0.bn1           (layer4_0_conv1,)                      {}
call_module  layer4_0_relu0         layer4.0.relu0         (layer4_0_bn1,)                        {}
call_module  layer4_0_conv2         layer4.0.conv2         (layer4_0_relu0,)                      {}
call_module  layer4_0_bn2           layer4.0.bn2           (layer4_0_conv2,)                      {}
call_module  layer4_0_downsample_0  layer4.0.downsample.0  (layer3_1_relu1,)                      {}
call_module  layer4_0_downsample_1  layer4.0.downsample.1  (layer4_0_downsample_0,)               {}
call_module  sum_module6            sum_module6            (layer4_0_bn2, layer4_0_downsample_1)  {}
call_module  layer4_0_relu1         layer4.0.relu1         (sum_module6,)                         {}
call_module  layer4_1_conv1         layer4.1.conv1         (layer4_0_relu1,)                      {}
call_module  layer4_1_bn1           layer4.1.bn1           (layer4_1_conv1,)                      {}
call_module  layer4_1_relu0         layer4.1.relu0         (layer4_1_bn1,)                        {}
call_module  layer4_1_conv2         layer4.1.conv2         (layer4_1_relu0,)                      {}
call_module  layer4_1_bn2           layer4.1.bn2           (layer4_1_conv2,)                      {}
call_module  sum_module7            sum_module7            (layer4_1_bn2, layer4_0_relu1)         {}
call_module  layer4_1_relu1         layer4.1.relu1         (sum_module7,)                         {}
call_module  avgpool                avgpool                (layer4_1_relu1,)                      {}
call_module  flatten0               flatten0               (avgpool,)                             {}
call_module  fc                     fc                     (flatten0,)                            {}
output       output                 output                 (fc,)                                  {}

Admittedly, the converted resnet18’s graph is quite large. Note, however, that it consists entirely of modules (indicated by call_module in the first column of the table), so BackPACK’s hooks can backpropagate the additional information required by its second-order extensions (first-order extensions work, too).
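For contrast, tracing a hand-written residual block with plain torch.fx.symbolic_trace (not BackPACK's tracer or converter) leaves the functional ReLU and the residual addition as call_function nodes. These are not modules, so no module hooks can be attached to them. The class TinyResBlock is a hypothetical example in the spirit of MyFirstResNet; in the converted resnet18 above, the corresponding additions appear as modules such as sum_module0.

```python
import operator

from torch.fx import symbolic_trace
from torch.nn import Conv2d, Module
from torch.nn.functional import relu


class TinyResBlock(Module):
    """Hypothetical hand-written residual block with an overridden forward."""

    def __init__(self, channels: int = 2):
        super().__init__()
        self.conv1 = Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv2(relu(self.conv1(x))) + x  # addition is not a module


graph_module = symbolic_trace(TinyResBlock())
for node in graph_module.graph.nodes:
    print(f"{node.op:<14} {node.name}")

# Nodes that module hooks cannot see
non_module_nodes = [n for n in graph_module.graph.nodes if n.op == "call_function"]
print([n.name for n in non_module_nodes])
```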

Let’s verify that second-order extensions are working:

x = rand(4, 3, 7, 7, device=DEVICE)  # small input for speed, instead of e.g. (128, 3, 224, 224)
output = model(x)
y = rand_like(output)

loss = loss_function(output, y)

with backpack(DiagGGNExact()):
    loss.backward()

for name, parameter in model.named_parameters():
    print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")

diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in model.parameters()])

Out:

conv1.weight's diag_ggn_exact: torch.Size([64, 3, 7, 7])
bn1.weight's diag_ggn_exact: torch.Size([64])
bn1.bias's diag_ggn_exact: torch.Size([64])
layer1.0.conv1.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.0.bn1.weight's diag_ggn_exact: torch.Size([64])
layer1.0.bn1.bias's diag_ggn_exact: torch.Size([64])
layer1.0.conv2.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight's diag_ggn_exact: torch.Size([64])
layer1.0.bn2.bias's diag_ggn_exact: torch.Size([64])
layer1.1.conv1.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.1.bn1.weight's diag_ggn_exact: torch.Size([64])
layer1.1.bn1.bias's diag_ggn_exact: torch.Size([64])
layer1.1.conv2.weight's diag_ggn_exact: torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight's diag_ggn_exact: torch.Size([64])
layer1.1.bn2.bias's diag_ggn_exact: torch.Size([64])
layer2.0.conv1.weight's diag_ggn_exact: torch.Size([128, 64, 3, 3])
layer2.0.bn1.weight's diag_ggn_exact: torch.Size([128])
layer2.0.bn1.bias's diag_ggn_exact: torch.Size([128])
layer2.0.conv2.weight's diag_ggn_exact: torch.Size([128, 128, 3, 3])
layer2.0.bn2.weight's diag_ggn_exact: torch.Size([128])
layer2.0.bn2.bias's diag_ggn_exact: torch.Size([128])
layer2.0.downsample.0.weight's diag_ggn_exact: torch.Size([128, 64, 1, 1])
layer2.0.downsample.1.weight's diag_ggn_exact: torch.Size([128])
layer2.0.downsample.1.bias's diag_ggn_exact: torch.Size([128])
layer2.1.conv1.weight's diag_ggn_exact: torch.Size([128, 128, 3, 3])
layer2.1.bn1.weight's diag_ggn_exact: torch.Size([128])
layer2.1.bn1.bias's diag_ggn_exact: torch.Size([128])
layer2.1.conv2.weight's diag_ggn_exact: torch.Size([128, 128, 3, 3])
layer2.1.bn2.weight's diag_ggn_exact: torch.Size([128])
layer2.1.bn2.bias's diag_ggn_exact: torch.Size([128])
layer3.0.conv1.weight's diag_ggn_exact: torch.Size([256, 128, 3, 3])
layer3.0.bn1.weight's diag_ggn_exact: torch.Size([256])
layer3.0.bn1.bias's diag_ggn_exact: torch.Size([256])
layer3.0.conv2.weight's diag_ggn_exact: torch.Size([256, 256, 3, 3])
layer3.0.bn2.weight's diag_ggn_exact: torch.Size([256])
layer3.0.bn2.bias's diag_ggn_exact: torch.Size([256])
layer3.0.downsample.0.weight's diag_ggn_exact: torch.Size([256, 128, 1, 1])
layer3.0.downsample.1.weight's diag_ggn_exact: torch.Size([256])
layer3.0.downsample.1.bias's diag_ggn_exact: torch.Size([256])
layer3.1.conv1.weight's diag_ggn_exact: torch.Size([256, 256, 3, 3])
layer3.1.bn1.weight's diag_ggn_exact: torch.Size([256])
layer3.1.bn1.bias's diag_ggn_exact: torch.Size([256])
layer3.1.conv2.weight's diag_ggn_exact: torch.Size([256, 256, 3, 3])
layer3.1.bn2.weight's diag_ggn_exact: torch.Size([256])
layer3.1.bn2.bias's diag_ggn_exact: torch.Size([256])
layer4.0.conv1.weight's diag_ggn_exact: torch.Size([512, 256, 3, 3])
layer4.0.bn1.weight's diag_ggn_exact: torch.Size([512])
layer4.0.bn1.bias's diag_ggn_exact: torch.Size([512])
layer4.0.conv2.weight's diag_ggn_exact: torch.Size([512, 512, 3, 3])
layer4.0.bn2.weight's diag_ggn_exact: torch.Size([512])
layer4.0.bn2.bias's diag_ggn_exact: torch.Size([512])
layer4.0.downsample.0.weight's diag_ggn_exact: torch.Size([512, 256, 1, 1])
layer4.0.downsample.1.weight's diag_ggn_exact: torch.Size([512])
layer4.0.downsample.1.bias's diag_ggn_exact: torch.Size([512])
layer4.1.conv1.weight's diag_ggn_exact: torch.Size([512, 512, 3, 3])
layer4.1.bn1.weight's diag_ggn_exact: torch.Size([512])
layer4.1.bn1.bias's diag_ggn_exact: torch.Size([512])
layer4.1.conv2.weight's diag_ggn_exact: torch.Size([512, 512, 3, 3])
layer4.1.bn2.weight's diag_ggn_exact: torch.Size([512])
layer4.1.bn2.bias's diag_ggn_exact: torch.Size([512])
fc.weight's diag_ggn_exact: torch.Size([5, 512])
fc.bias's diag_ggn_exact: torch.Size([5])
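The example feeds 7×7 images instead of the usual 224×224 to keep the run fast. The standard output-size formula floor((n + 2p - k) / s) + 1 shows why this still works: the spatial size shrinks to 1×1 but never to 0, and the final adaptive average pooling accepts any positive size. The stage hyperparameters below are our reading of torchvision's resnet18 (first, size-determining layer of each stage); treat them as an assumption.

```python
from typing import List


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding), assumed from torchvision's resnet18;
# the remaining 3x3 convolutions with stride 1 and padding 1 keep the size
RESNET18_STAGES = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer1", 3, 1, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]


def resnet18_spatial_sizes(size: int) -> List[int]:
    """Spatial size after each stage for a square input of side length `size`."""
    sizes = []
    for _name, kernel, stride, padding in RESNET18_STAGES:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


print(resnet18_spatial_sizes(224))  # [112, 56, 56, 28, 14, 7]
print(resnet18_spatial_sizes(7))  # [4, 2, 2, 1, 1, 1]
```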

Comparison with torch.autograd:

Note

Computing the full GGN diagonal with PyTorch’s built-in automatic differentiation can be slow, depending on the number of parameters. To reduce run time, we only compare some elements of the diagonal.

num_params = sum(p.numel() for p in model.parameters())
num_to_compare = 10
idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)

diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
    x, y, model, loss_function, idx=idx_to_compare
)

print("Do the exact GGN diagonals match?")
for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
    match = allclose(element, diag_ggn_exact_vector[idx], atol=1e-6)
    print(f"Diagonal entry {idx:>8}: {match}")
    if not match:
        raise AssertionError("Exact GGN diagonals don't match!")

Out:

Do the exact GGN diagonals match?
Diagonal entry        0: True
Diagonal entry  1242119: True
Diagonal entry  2484239: True
Diagonal entry  3726358: True
Diagonal entry  4968478: True
Diagonal entry  6210597: True
Diagonal entry  7452717: True
Diagonal entry  8694836: True
Diagonal entry  9936956: True
Diagonal entry 11179076: True

Total running time of the script: (0 minutes 3.249 seconds)
