Custom module example

This tutorial shows how to support a custom module in a simple fashion. We focus on BackPACK’s first-order extensions. They don’t backpropagate additional information and thus require less functionality to be implemented.

Let’s get the imports out of our way.

import torch

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.extensions.firstorder.base import FirstOrderModuleExtension

# make deterministic
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Custom PyTorch module

In this example, we extend our own, very simple layer. In the forward pass, it scales the input by a scalar weight. Here is the layer class (see https://pytorch.org/docs/stable/notes/extending.html).

class ScaleModule(torch.nn.Module):
    """Defines the module."""

    def __init__(self, weight=2.0):
        """Store scalar weight.

        Args:
            weight(float, optional): Initial value for weight. Defaults to 2.0.
        """
        super().__init__()

        self.weight = torch.nn.Parameter(torch.tensor([weight]))

    def forward(self, input):
        """Defines forward pass.

        Args:
            input(torch.Tensor): input

        Returns:
            torch.Tensor: product of input and weight
        """
        return input * self.weight
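Because weight has shape [1], PyTorch broadcasting lets the same module scale inputs of any shape. A quick sanity check, assuming only PyTorch; it re-declares a compact copy of ScaleModule so the snippet runs on its own:

```python
import torch


class ScaleModule(torch.nn.Module):
    """Compact copy of the tutorial's layer: multiplies the input by a scalar weight."""

    def __init__(self, weight=2.0):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([weight]))

    def forward(self, input):
        return input * self.weight


layer = ScaleModule()

# A [1]-shaped weight broadcasts against any input shape.
x_vec = torch.randn(10, 4)        # batch of feature vectors
x_img = torch.randn(10, 3, 5, 5)  # batch of image-like tensors

out_vec = layer(x_vec)
out_img = layer(x_img)

print(out_vec.shape)  # torch.Size([10, 4])
print(out_img.shape)  # torch.Size([10, 3, 5, 5])
print(torch.allclose(out_vec, 2.0 * x_vec))  # True
```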

You don’t necessarily need to write a custom layer. Any PyTorch layer can be extended as described, as long as it is a torch.nn.Module, because BackPACK relies on module hooks.
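Module hooks are what give an extension access to a module’s inputs and outputs. The snippet below is not BackPACK’s code; it is a minimal sketch of the underlying PyTorch mechanism, a forward hook registered with register_forward_hook that stores the input and output on the module. The attribute names stored_input and stored_output are hypothetical and chosen only for illustration.

```python
import torch

layer = torch.nn.Linear(4, 3)


def store_io(module, inputs, output):
    """Forward hook: inputs is the tuple of positional args, output is the return value."""
    # Hypothetical attribute names, for illustration only.
    module.stored_input = inputs[0].detach()
    module.stored_output = output.detach()


handle = layer.register_forward_hook(store_io)

x = torch.randn(10, 4)
y = layer(x)

print(layer.stored_input.shape)   # torch.Size([10, 4])
print(layer.stored_output.shape)  # torch.Size([10, 3])

handle.remove()  # detach the hook once it is no longer needed
```

Hooks like this are attached to torch.nn.Module instances, which is why a plain Python function cannot be extended the same way.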

Custom module extension

Let’s make BackPACK support computing individual gradients for ScaleModule. Individual gradients are computed by the BatchGrad extension. To support the new module, we need to create a module extension that implements how individual gradients are extracted with respect to ScaleModule’s parameter.
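For ScaleModule, the individual gradient follows from the chain rule. Writing the forward pass per sample n and non-batch entry d, and assuming "sum" reduction so that the gradient with respect to sample n’s output only involves that sample’s loss:

```latex
\begin{aligned}
y_{n,d} &= w \, x_{n,d}, \qquad L = \sum_{n=1}^{N} \ell_n(\mathbf{y}_n), \\
\frac{\partial \ell_n}{\partial w}
  &= \sum_{d} \frac{\partial \ell_n}{\partial y_{n,d}} \, \frac{\partial y_{n,d}}{\partial w}
   = \sum_{d} \frac{\partial L}{\partial y_{n,d}} \, x_{n,d}.
\end{aligned}
```

The first factor in the last sum is the gradient with respect to the module output (g_out[0] in the code), the second is the stored input (module.input0); d runs over all non-batch entries of the input.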

The module extension must implement one method per parameter name passed to the constructor, and each method must be named exactly like that parameter. Here it is.
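Conceptually, for each name listed in params, the extension looks up a method with the same name and calls it. A minimal pure-Python sketch of this naming convention; MiniModuleExtension, TwoParamExtension, and apply are hypothetical names and do not reproduce BackPACK’s internals:

```python
class MiniModuleExtension:
    """Hypothetical stand-in for a module extension base class."""

    def __init__(self, params):
        self.params = params

    def apply(self, module):
        # For every parameter name, call the method with the same name
        # and collect the result under that name.
        results = {}
        for name in self.params:
            compute = getattr(self, name)
            results[name] = compute(module)
        return results


class TwoParamExtension(MiniModuleExtension):
    def __init__(self):
        super().__init__(params=["weight", "bias"])

    def weight(self, module):
        return f"quantity for {module}.weight"

    def bias(self, module):
        return f"quantity for {module}.bias"


print(TwoParamExtension().apply("layer"))
# {'weight': 'quantity for layer.weight', 'bias': 'quantity for layer.bias'}
```

In this sketch, listing a parameter name without defining a matching method raises an AttributeError at lookup time.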

class ScaleModuleBatchGrad(FirstOrderModuleExtension):
    """Extract individual gradients for ``ScaleModule``."""

    def __init__(self):
        """Store parameters for which individual gradients should be computed."""
        # specify parameter names
        super().__init__(params=["weight"])

    def weight(self, ext, module, g_inp, g_out, bpQuantities):
        """Extract individual gradients for ScaleModule's ``weight`` parameter.

        Args:
            ext(BatchGrad): extension that is used
            module(ScaleModule): module that performed forward pass
            g_inp(tuple[torch.Tensor]): input gradient tensors
            g_out(tuple[torch.Tensor]): output gradient tensors
            bpQuantities(None): additional quantities for second-order

        Returns:
            torch.Tensor: individual gradients
        """
        show_useful = True

        if show_useful:
            print("Useful quantities:")
            # output is saved under field output
            print("\tmodule.output.shape:", module.output.shape)
            # input i is saved under field input[i]
            print("\tmodule.input0.shape:", module.input0.shape)
            # gradient w.r.t output
            print("\tg_out[0].shape:     ", g_out[0].shape)

        # individual gradient of sample n: sum of g_out[0] * input0 over all non-batch axes
        return (g_out[0] * module.input0).flatten(start_dim=1).sum(dim=1).unsqueeze(-1)
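The flatten(start_dim=1) step lets the same line handle inputs with any number of non-batch dimensions. In this sketch, random tensors stand in for g_out[0] and module.input0 (g_out0 and input0 are placeholder names), using an image-shaped batch:

```python
import torch

torch.manual_seed(0)

N = 10
g_out0 = torch.randn(N, 3, 5, 5)  # stand-in for g_out[0]
input0 = torch.randn(N, 3, 5, 5)  # stand-in for module.input0

prod = g_out0 * input0             # [N, 3, 5, 5]
flat = prod.flatten(start_dim=1)   # [N, 75]
summed = flat.sum(dim=1)           # [N]
grad_batch = summed.unsqueeze(-1)  # [N, 1]: weight's shape [1] plus a leading batch axis

print(grad_batch.shape)  # torch.Size([10, 1])

# Equivalent: sum over all non-batch dimensions directly.
# A loose atol guards against float32 differences from summation order.
alt = (g_out0 * input0).sum(dim=(1, 2, 3)).unsqueeze(-1)
print(torch.allclose(grad_batch, alt, atol=1e-5))  # True
```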

Lastly, we need to register the mapping between layer (ScaleModule) and layer extension (ScaleModuleBatchGrad) in an instance of BatchGrad.

# register module-computation mapping
extension = BatchGrad()
extension.set_module_extension(ScaleModule, ScaleModuleBatchGrad())
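Registering a module extension can be pictured as filling a lookup table keyed by module class. A hypothetical pure-Python sketch; MiniExtension, register, lookup, ScaleLayer, and OtherLayer are illustrative names, not BackPACK’s implementation:

```python
class MiniExtension:
    """Hypothetical registry that maps module classes to module extensions."""

    def __init__(self):
        self._table = {}

    def register(self, module_class, module_extension):
        self._table[module_class] = module_extension

    def lookup(self, module):
        # Exact class match: this sketch does not resolve subclasses.
        return self._table.get(type(module))


class ScaleLayer:  # hypothetical stand-in for ScaleModule
    pass


class OtherLayer:
    pass


ext = MiniExtension()
ext.register(ScaleLayer, "ScaleLayer handler")

print(ext.lookup(ScaleLayer()))  # ScaleLayer handler
print(ext.lookup(OtherLayer()))  # None
```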

That’s it. We can now pass extension to a with backpack(...) context and compute individual gradients with respect to ScaleModule’s weight parameter.

Test custom module

Here, we verify the custom module extension on a single ScaleModule layer with random inputs. Let’s create these.

batch_size = 10
batch_axis = 0
input_size = 4

inputs = torch.randn(batch_size, input_size, device=device)
targets = torch.randint(0, 2, (batch_size,), device=device)

reduction = ["mean", "sum"][1]
my_module = ScaleModule().to(device)
lossfunc = torch.nn.CrossEntropyLoss(reduction=reduction).to(device)

Note

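Results of "mean" and "sum" reduction differ by a scaling factor, because the information backpropagated by PyTorch is scaled: with "mean", the loss is divided by batch_size, so the individual gradients returned by BatchGrad are scaled by 1/batch_size. The per-sample loop below evaluates each sample’s loss in isolation and is unaffected by this factor, which is why this example uses "sum". This is documented at https://docs.backpack.pt/en/master/extensions.html#backpack.extensions.BatchGrad.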
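The scaling can be observed with plain autograd, assuming only PyTorch: the gradient of the loss with respect to the logits, which is what the last layer receives as g_out, differs by exactly a factor of N between the two reductions.

```python
import torch

torch.manual_seed(0)

N, C = 8, 4
logits = torch.randn(N, C, requires_grad=True)
targets = torch.randint(0, C, (N,))

grads = {}
for reduction in ["mean", "sum"]:
    loss = torch.nn.CrossEntropyLoss(reduction=reduction)(logits, targets)
    # Gradient w.r.t. the logits, i.e. what the module producing them receives as g_out.
    grads[reduction] = torch.autograd.grad(loss, logits)[0]

# "mean" divides the summed loss by N, so every backpropagated row shrinks by 1/N.
# The small atol only guards against float32 rounding.
print(torch.allclose(grads["mean"] * N, grads["sum"], atol=1e-6))  # True
```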

Individual gradients with PyTorch

The following computes individual gradients by looping over individual samples and stacking their gradients.

grad_batch_autograd = []

for input_n, target_n in zip(
    inputs.split(1, dim=batch_axis), targets.split(1, dim=batch_axis)
):
    loss_n = lossfunc(my_module(input_n), target_n)
    grad_n = torch.autograd.grad(loss_n, [my_module.weight])[0]
    grad_batch_autograd.append(grad_n)

grad_batch_autograd = torch.stack(grad_batch_autograd)

print("weight.shape:             ", my_module.weight.shape)
print("grad_batch_autograd.shape:", grad_batch_autograd.shape)

Out:

weight.shape:              torch.Size([1])
grad_batch_autograd.shape: torch.Size([10, 1])
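The loop above needs one backward pass per sample. For ScaleModule, a single backward pass already provides everything required: the gradient of the summed loss with respect to the layer output, combined with the input through the same formula used in ScaleModuleBatchGrad.weight. A self-contained sketch with plain autograd (no BackPACK), re-declaring a compact ScaleModule:

```python
import torch

torch.manual_seed(0)


class ScaleModule(torch.nn.Module):
    def __init__(self, weight=2.0):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.tensor([weight]))

    def forward(self, input):
        return input * self.weight


N, D = 10, 4
inputs = torch.randn(N, D)
targets = torch.randint(0, 2, (N,))
layer = ScaleModule()
lossfunc = torch.nn.CrossEntropyLoss(reduction="sum")

# Reference: one backward pass per sample.
loop = torch.stack(
    [
        torch.autograd.grad(lossfunc(layer(x), t), [layer.weight])[0]
        for x, t in zip(inputs.split(1), targets.split(1))
    ]
)

# Single backward pass: gradient w.r.t. the layer output, then the closed form.
output = layer(inputs)
g_out = torch.autograd.grad(lossfunc(output, targets), output)[0]
single = (g_out * inputs).flatten(start_dim=1).sum(dim=1).unsqueeze(-1)

print(loop.shape, single.shape)  # torch.Size([10, 1]) torch.Size([10, 1])
# A small atol guards against float32 cancellation in near-zero gradients.
print(torch.allclose(loop, single, atol=1e-6))  # True
```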

Individual gradients with BackPACK

BackPACK can compute individual gradients in a single backward pass. First, extend model and loss function, then perform the backpropagation inside a with backpack(...) context.

my_module = extend(my_module)
lossfunc = extend(lossfunc)

loss = lossfunc(my_module(inputs), targets)

with backpack(extension):
    loss.backward()

grad_batch_backpack = my_module.weight.grad_batch

print("weight.shape:             ", my_module.weight.shape)
print("grad_batch_backpack.shape:", grad_batch_backpack.shape)

Out:

Useful quantities:
        module.output.shape: torch.Size([10, 4])
        module.input0.shape: torch.Size([10, 4])
        g_out[0].shape:      torch.Size([10, 4])
weight.shape:              torch.Size([1])
grad_batch_backpack.shape: torch.Size([10, 1])

Do the computation results match?

match = torch.allclose(grad_batch_autograd, grad_batch_backpack)

print(f"autograd and BackPACK individual gradients match? {match}")

if not match:
    raise AssertionError(
        "Individual gradients don't match:"
        + f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}"
    )

Out:

autograd and BackPACK individual gradients match? True
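torch.allclose(a, b) checks elementwise that |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08. A pure-Python sketch of the same criterion for flat lists; allclose_sketch is a hypothetical name and ignores NaN handling:

```python
def allclose_sketch(a, b, rtol=1e-05, atol=1e-08):
    """Elementwise |a - b| <= atol + rtol * |b| for equally long lists."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(allclose_sketch([1.0, 2.0], [1.0 + 1e-6, 2.0]))  # True: 1e-6 <= 1e-8 + 1e-5 * 1.000001
print(allclose_sketch([1e-3], [1e-3 + 1e-7]))          # False: 1e-7 > 1e-8 + 1e-5 * 1.0001e-3
```

The second case shows that the relative term offers little slack for values close to zero, where the absolute tolerance atol dominates.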
