First order extensions with a ResNet

Let’s get the imports, configuration and a helper function out of the way first.

import torch
import torch.nn.functional as F

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist

BATCH_SIZE = 3
torch.manual_seed(0)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def get_accuracy(output, targets):
    """Helper function to print the accuracy"""
    predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
    return predictions.eq(targets).float().mean().item()


x, y = load_one_batch_mnist(batch_size=BATCH_SIZE)
x, y = x.to(DEVICE), y.to(DEVICE)

We can build a ResNet by subclassing torch.nn.Module. As long as the layers with parameters (torch.nn.Conv2d and torch.nn.Linear) are nn modules, BackPACK can extend them, and that is all that is needed for first order extensions. We write the forward pass to implement the residual connection, then call extend() on the resulting model.

class MyFirstResNet(torch.nn.Module):
    def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
        super().__init__()

        self.conv1 = torch.nn.Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
        self.linear1 = torch.nn.Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)
        if C_in == C_hid:
            self.shortcut = torch.nn.Identity()
        else:
            self.shortcut = torch.nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1)

    def forward(self, x):
        residual = self.shortcut(x)
        x = self.conv2(F.relu(self.conv1(x)))
        x += residual
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        return x


model = extend(MyFirstResNet()).to(DEVICE)

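As a quick sanity check (our addition, not part of the original script), the extended model can be used like any torch.nn.Module in the forward pass, for example with the accuracy helper defined above; on an untrained model the result will be around chance level.

print("Accuracy on the batch: {:.2f}".format(get_accuracy(model(x), y)))
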
Using BatchGrad in a with backpack(...) block, we can access the individual gradients for each sample.

The loss does not need to be extended in this case either, since it has no model parameters and BackPACK does not need to know about it for first order extensions. This also means you can use any custom loss function (see the sketch after the output below).

model.zero_grad()
loss = F.cross_entropy(model(x), y, reduction="sum")
with backpack(BatchGrad()):
    loss.backward()

print("{:<20}  {:<30} {:<30}".format("Param", "grad", "grad (batch)"))
print("-" * 80)
for name, p in model.named_parameters():
    print(
        "{:<20}: {:<30} {:<30}".format(name, str(p.grad.shape), str(p.grad_batch.shape))
    )

Out:

Param                 grad                           grad (batch)
--------------------------------------------------------------------------------
conv1.weight        : torch.Size([5, 1, 3, 3])       torch.Size([3, 5, 1, 3, 3])
conv1.bias          : torch.Size([5])                torch.Size([3, 5])
conv2.weight        : torch.Size([5, 5, 3, 3])       torch.Size([3, 5, 5, 3, 3])
conv2.bias          : torch.Size([5])                torch.Size([3, 5])
linear1.weight      : torch.Size([10, 3920])         torch.Size([3, 10, 3920])
linear1.bias        : torch.Size([10])               torch.Size([3, 10])
shortcut.weight     : torch.Size([5, 1, 1, 1])       torch.Size([3, 5, 1, 1, 1])
shortcut.bias       : torch.Size([5])                torch.Size([3, 5])

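To illustrate the point about custom losses, here is a minimal sketch that replaces F.cross_entropy with a hand-written summed negative log-likelihood (the name my_nll_loss is ours, not part of BackPACK). Since this loss is mathematically equivalent to the cross-entropy used above, the stored grad_batch quantities are numerically unchanged and the comparison below still applies.

def my_nll_loss(logits, targets):
    """Hand-written summed negative log-likelihood, equivalent to
    F.cross_entropy(..., reduction="sum")."""
    log_probs = F.log_softmax(logits, dim=1)
    idx = torch.arange(targets.shape[0], device=targets.device)
    return -log_probs[idx, targets].sum()


model.zero_grad()
loss = my_nll_loss(model(x), y)
with backpack(BatchGrad()):
    loss.backward()

# Individual gradients are available just as before.
print(model.conv1.weight.grad_batch.shape)  # expect torch.Size([3, 5, 1, 3, 3])
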
To check that everything works, let’s compute one individual gradient with PyTorch (using a single sample in a forward and backward pass) and compare it with the one computed by BackPACK. Since the batch loss above used reduction="sum", each entry of grad_batch is the gradient of a single sample's loss term and can be compared directly with a single-sample backward pass.

sample_to_check = 1
x_to_check = x[sample_to_check, :].unsqueeze(0)
y_to_check = y[sample_to_check].unsqueeze(0)

model.zero_grad()
loss = F.cross_entropy(model(x_to_check), y_to_check)
loss.backward()

print("Do the individual gradients match?")
for name, p in model.named_parameters():
    match = torch.allclose(p.grad_batch[sample_to_check, :], p.grad, atol=1e-7)
    print("{:<20} {}".format(name, match))

Out:

Do the individual gradients match?
conv1.weight         True
conv1.bias           True
conv2.weight         True
conv2.bias           True
linear1.weight       True
linear1.bias         True
shortcut.weight      True
shortcut.bias        True

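As a final illustration (our addition, not part of the original script), the grad_batch attributes make it easy to compute per-sample gradient norms, which are useful for example in gradient clipping or variance diagnostics.

model.zero_grad()
loss = F.cross_entropy(model(x), y, reduction="sum")
with backpack(BatchGrad()):
    loss.backward()

# Squared L2 norm of each sample's gradient, accumulated over all parameters.
squared_norms = sum(
    p.grad_batch.flatten(start_dim=1).pow(2).sum(dim=1) for p in model.parameters()
)
print(squared_norms.sqrt())  # one gradient norm per sample, shape [BATCH_SIZE]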