BackPACK’s retain_graph option

This tutorial demonstrates how to use BackPACK's retain_graph option to perform multiple backward passes through the same computation graph. This option can be useful if you run into out-of-memory errors: if your computation can be chunked, you might consider distributing it over multiple backward passes to reduce peak memory.

Our use case is computing the GGN diagonal of an auto-encoder’s reconstruction error.

But first, the imports:

from functools import partial
from time import time
from typing import List

from memory_profiler import memory_usage
from torch import Tensor, allclose, manual_seed, rand, zeros_like
from torch.nn import Conv2d, ConvTranspose2d, Flatten, MSELoss, Sequential, Sigmoid

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

# make deterministic
manual_seed(0)

Setup

Let \(f_{\mathbf{\theta}}\) denote the auto-encoder, and \(\mathbf{x'} = f_{\mathbf{\theta}}(\mathbf{x}) \in \mathbb{R}^M\) its reconstruction of an input \(\mathbf{x} \in \mathbb{R}^M\). The associated reconstruction error is measured by the mean squared error

\[\ell(\mathbf{\theta}) = \frac{1}{M} \left\lVert f_{\mathbf{\theta}}(\mathbf{x}) - \mathbf{x} \right\rVert^2_2\,.\]

On a batch of \(N\) examples, \(\mathbf{x}_1, \dots, \mathbf{x}_N\), the loss is

\[\mathcal{L}(\mathbf{\theta}) = \frac{1}{N} \frac{1}{M} \sum_{n=1}^N \left\lVert f_{\mathbf{\theta}}(\mathbf{x}_n) - \mathbf{x}_n \right\rVert^2_2\,.\]

Let’s create a toy model and data:

# data
batch_size, channels, spatial_dims = 5, 3, (32, 32)
X = rand(batch_size, channels, *spatial_dims)

# model (auto-encoder)
hidden_channels = 10

encoder = Sequential(
    Conv2d(channels, hidden_channels, 3),
    Sigmoid(),
)
decoder = Sequential(
    ConvTranspose2d(hidden_channels, channels, 3),
    Flatten(),
)
model = Sequential(
    encoder,
    decoder,
)
loss_func = MSELoss()
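
As a quick sanity check (an illustrative addition, not part of the benchmark below), we can verify that MSELoss with its default reduction="mean" implements the \(\frac{1}{N} \frac{1}{M}\) normalization from the formula above:

reconstruction = model(X)
target = X.flatten(start_dim=1)
# mean squared error = sum of squared errors divided by N * M
manual = ((reconstruction - target) ** 2).sum() / target.numel()
print(allclose(manual, loss_func(reconstruction, target)))  # expected: True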

We will use BackPACK to compute the GGN diagonal of the mini-batch loss. To do that, we need to extend the model and the loss function.

model = extend(model)
loss_func = extend(loss_func)

GGN diagonal in one backward pass

As usual, we can compute the GGN diagonal for the mini-batch loss in a single backward pass. The following function does that:

def diag_ggn_one_pass() -> List[Tensor]:
    """Compute the GGN diagonal in a single backward pass.

    Returns:
        GGN diagonal in parameter list format.
    """
    reconstruction = model(X)
    error = loss_func(reconstruction, X.flatten(start_dim=1))

    with backpack(DiagGGNExact()):
        error.backward()

    return [p.diag_ggn_exact.clone() for p in model.parameters() if p.requires_grad]

Let’s run it and determine (i) its peak memory consumption and (ii) its run time.

print("GGN diagonal in one backward pass:")
start = time()
max_mem, diag_ggn = memory_usage(
    diag_ggn_one_pass, interval=1e-3, max_usage=True, retval=True
)
end = time()

print(f"\tPeak memory [MiB]: {max_mem:.2e}")
print(f"\tTime [s]: {end-start:.2e}")
GGN diagonal in one backward pass:
        Peak memory [MiB]: 2.42e+03
        Time [s]: 3.94e+00

The memory consumption is pretty high, although our model is relatively small! If we make the model deeper, or increase the mini-batch size, we will quickly run out of memory.

This is because computing the GGN diagonal scales with the network’s output dimension. For classification settings like MNIST and CIFAR-10, this number is relatively small (10 classes). But for an auto-encoder, the output dimension equals the input dimension \(M\), which in our case is

print(f"Output dimension: {model(X).shape[1:].numel()}")
Output dimension: 3072

We will now take a look at how to circumvent the high peak memory by distributing the computation over multiple backward passes.

GGN diagonal in chunks

The GGN diagonal computation can be distributed across multiple backward passes. This greatly reduces peak memory consumption.

To see this, let’s consider the GGN diagonal for a single example \(\mathbf{x}\),

\[\mathrm{diag} \left( \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]^\top \left[ \frac{2}{M} \mathbf{I}_{M\times M} \right] \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right)\,,\]

with the \(M \times |\mathbf{\theta}|\) Jacobian \(\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})\) of the model, and \(\frac{2}{M} \mathbf{I}_{M\times M}\) the mean squared error’s Hessian w.r.t. the reconstructed input. Here you can see that the memory consumption scales with the output dimension, as we need to compute \(M\) vector-Jacobian products.

Let \(S\), the chunk size, be a number that divides the output dimension \(M\). Then, we can decompose the above computation into chunks:

\[\begin{split}\frac{S}{M} \left\{ \mathrm{diag} \left( \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]^\top_{:, 0:S} \left[ \frac{2}{S} \mathbf{I}_{S\times S} \right] \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]_{0:S, :} \right) \right. \\ + \left. \mathrm{diag} \left( \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]^\top_{:, S:2S} \left[ \frac{2}{S} \mathbf{I}_{S\times S} \right] \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]_{S:2S, :} \right) \right. \\ + \left. \mathrm{diag} \left( \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]^\top_{:, 2S:3S} \left[ \frac{2}{S} \mathbf{I}_{S\times S} \right] \left[ \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x}) \right]_{2S:3S, :} \right) + \dots \right\}\,.\end{split}\]

Each summand is the GGN diagonal of the mean squared error on a chunk

\[\tilde{\ell}(\mathbf{\theta}) = \frac{1}{S} \lVert [f_{\mathbf{\theta}}(\mathbf{x})]_{i S: (i+1) S} - [\mathbf{x}]_{i S: (i+1) S} \rVert_2^2\,, \qquad i = 0, 1, \dots, \frac{M}{S} - 1\,,\]

and its memory consumption scales with \(S < M\).
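
To make this identity concrete, here is a small numerical check with a random stand-in matrix in place of the Jacobian (the dimensions and the matrix J below are made up for illustration):

from torch import eye

M_toy, S_toy, num_params = 6, 2, 4  # the chunk size S_toy divides M_toy
J = rand(M_toy, num_params)  # stand-in for the M x |theta| Jacobian

# full GGN diagonal: diag(J^T [2/M I] J)
full = (J.T @ (2 / M_toy * eye(M_toy)) @ J).diag()

# chunked version: S/M times the sum of per-chunk diagonals with Hessian 2/S I
chunked = (S_toy / M_toy) * sum(
    (J[i : i + S_toy].T @ (2 / S_toy * eye(S_toy)) @ J[i : i + S_toy]).diag()
    for i in range(0, M_toy, S_toy)
)

print(allclose(full, chunked))  # expected: True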

In summary, the computation split works as follows:

  • Compute \(f_{\mathbf{\theta}}(\mathbf{x})\) in a single forward pass.

  • Compute the reconstruction error for a chunk and its GGN diagonal in one backward pass.

  • Repeat the last step for the other chunks. Accumulate the GGN diagonals over all chunks.

(This carries over to the mini-batch case in a straightforward fashion. We omit it here to avoid the more involved notation.)

Note that because we perform multiple backward passes, we need to tell PyTorch (and BackPACK) to retain the graph.
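
For context, here is a minimal, generic PyTorch snippet (independent of BackPACK) showing why retain_graph is needed: by default, the graph is freed after the first backward pass, so a second pass through it would fail.

x = rand(3, requires_grad=True)
y = (x ** 2).sum()
y.backward(retain_graph=True)  # keep the graph alive ...
y.backward()  # ... so that a second backward pass through it succeeds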

To slice out a chunk, we use BackPACK’s Slicing module.
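
As a quick illustration (assuming, as in the implementation below, that Slicing simply applies the given tuple of slices to its input):

demo_input = rand(5, 8)
demo_slicing = Slicing((slice(None), slice(0, 4)))  # keep all rows, columns 0..3
print(allclose(demo_slicing(demo_input), demo_input[:, 0:4]))  # expected: True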

Here is the implementation:

def diag_ggn_multiple_passes(num_chunks: int) -> List[Tensor]:
    """Compute the GGN diagonal in multiple backward passes.

    Uses less memory than ``diag_ggn_one_pass`` if ``num_chunks > 1``.
    Does the same as ``diag_ggn_one_pass`` for ``num_chunks = 1``.

    Args:
        num_chunks: Number of backward passes. Must divide the model's output dimension.

    Returns:
        GGN diagonal in parameter list format.

    Raises:
        ValueError:
            If ``num_chunks`` does not divide the model's output dimension.
        NotImplementedError:
            If the model does not return a batched vector (the slicing logic is only
            implemented for batched vectors, i.e. 2d tensors).
    """
    reconstruction = model(X)

    if reconstruction.shape[1:].numel() % num_chunks != 0:
        raise ValueError("Network output dimension must be divisible by number of chunks.")
    if reconstruction.dim() != 2:
        raise NotImplementedError("Slicing logic only implemented for 2d outputs.")

    chunk_size = reconstruction.shape[1:].numel() // num_chunks
    diag_ggn_exact = [zeros_like(p) for p in model.parameters()]

    for idx in range(num_chunks):
        # set up the layer that extracts the current slice
        slicing = (slice(None), slice(idx * chunk_size, (idx + 1) * chunk_size))
        chunk_module = extend(Slicing(slicing))

        # compute the chunk's loss
        sliced_reconstruction = chunk_module(reconstruction)
        sliced_X = X.flatten(start_dim=1)[slicing]
        slice_error = loss_func(sliced_reconstruction, sliced_X)

        # compute its GGN diagonal ...
        with backpack(DiagGGNExact(), retain_graph=True):
            slice_error.backward(retain_graph=True)

        # ... and accumulate it
        for p_idx, p in enumerate(model.parameters()):
            diag_ggn_exact[p_idx] += p.diag_ggn_exact

    # fix normalization
    return [ggn / num_chunks for ggn in diag_ggn_exact]

Let’s benchmark peak memory and run time for different numbers of chunks:

num_chunks = [1, 4, 16, 64]

for n in num_chunks:
    print(f"GGN diagonal in {n} backward passes:")
    start = time()
    max_mem, diag_ggn_chunk = memory_usage(
        partial(diag_ggn_multiple_passes, n), interval=1e-3, max_usage=True, retval=True
    )
    end = time()
    print(f"\tPeak memory [MiB]: {max_mem:.2e}")
    print(f"\tTime [s]: {end-start:.2e}")

    correct = [
        allclose(diag1, diag2, rtol=5e-3, atol=5e-5)
        for diag1, diag2 in zip(diag_ggn, diag_ggn_chunk)
    ]
    print(f"\tCorrect: {correct}")

    if not all(correct):
        raise RuntimeError("Mismatch in GGN diagonals.")
GGN diagonal in 1 backward passes:
        Peak memory [MiB]: 2.39e+03
        Time [s]: 4.67e+00
        Correct: [True, True, True, True]
GGN diagonal in 4 backward passes:
        Peak memory [MiB]: 1.22e+03
        Time [s]: 4.19e+00
        Correct: [True, True, True, True]
GGN diagonal in 16 backward passes:
        Peak memory [MiB]: 9.22e+02
        Time [s]: 3.53e+00
        Correct: [True, True, True, True]
GGN diagonal in 64 backward passes:
        Peak memory [MiB]: 8.51e+02
        Time [s]: 1.90e+00
        Correct: [True, True, True, True]

We can see that using more chunks consistently decreases the peak memory. In this example, the run time decreases as well; however, beyond some sweet spot, increasing the number of chunks further will eventually slow down the computation. The details of this trade-off depend on your model and compute architecture.

Concluding remarks

Here, we considered chunking the computation along the auto-encoder’s output dimension. There are other ways to achieve the desired effect of reducing peak memory:

  • In the mini-batch setting, we could consider only a subset of the mini-batch samples in each backward pass. This can be done with the optional subsampling argument available in many of BackPACK’s extensions; see the mini-batch sub-sampling tutorial. This technique can be combined with the chunking described above.

  • We could turn off the gradient computation (and thereby BackPACK’s computation) for all but a subgroup of parameters by setting their requires_grad attribute to False, and compute the GGN diagonal only for that subgroup. However, for this to work, we need to perform a new forward pass for each parameter subgroup; a minimal sketch follows below.
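
The following sketch illustrates the second approach for the decoder’s parameters. It assumes BackPACK’s standard behavior of computing quantities only for parameters with requires_grad=True; the variable names are ours and the snippet is illustrative rather than part of the benchmark above.

# compute the GGN diagonal only for the decoder by disabling gradients for the encoder
for p in encoder.parameters():
    p.requires_grad = False

# a fresh forward pass is required after changing requires_grad
reconstruction = model(X)
error = loss_func(reconstruction, X.flatten(start_dim=1))

with backpack(DiagGGNExact()):
    error.backward()

decoder_diag_ggn = [p.diag_ggn_exact for p in decoder.parameters()]

# restore the original setting
for p in encoder.parameters():
    p.requires_grad = True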
