BackPACK’s retain_graph option
This tutorial demonstrates how to perform multiple backward passes through the same computation graph with BackPACK. This option can be useful if you run into out-of-memory errors: if your computation can be chunked, you might consider distributing it over multiple backward passes to reduce peak memory.
Our running example for such a computation is the GGN diagonal of an auto-encoder’s reconstruction error.
But first, the imports:
from functools import partial
from time import time
from typing import List
from memory_profiler import memory_usage
from torch import Tensor, allclose, manual_seed, rand, zeros_like
from torch.nn import Conv2d, ConvTranspose2d, Flatten, MSELoss, Sequential, Sigmoid
from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact
# make deterministic
manual_seed(0)
Setup
Let \(f_{\mathbf{\theta}}\) denote the auto-encoder, and \(\mathbf{x'} = f_{\mathbf{\theta}}(\mathbf{x}) \in \mathbb{R}^M\) its reconstruction of an input \(\mathbf{x} \in \mathbb{R}^M\). The associated reconstruction error is measured by the mean squared error
\(\ell(\mathbf{x}, \mathbf{\theta}) = \frac{1}{M} \lVert f_{\mathbf{\theta}}(\mathbf{x}) - \mathbf{x} \rVert_2^2\,.\)
On a batch of \(N\) examples, \(\mathbf{x}_1, \dots, \mathbf{x}_N\), the loss is
\(\mathcal{L}(\mathbf{\theta}) = \frac{1}{N} \sum_{n=1}^N \ell(\mathbf{x}_n, \mathbf{\theta})\,.\)
Let’s create a toy model and data:
# data
batch_size, channels, spatial_dims = 5, 3, (32, 32)
X = rand(batch_size, channels, *spatial_dims)
# model (auto-encoder)
hidden_channels = 10
encoder = Sequential(
Conv2d(channels, hidden_channels, 3),
Sigmoid(),
)
decoder = Sequential(
ConvTranspose2d(hidden_channels, channels, 3),
Flatten(),
)
model = Sequential(
encoder,
decoder,
)
loss_func = MSELoss()
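The equations above correspond to MSELoss with its default reduction="mean", which averages the squared error over all \(N \cdot M\) output entries. As a quick, optional sanity check (a sketch added for illustration, reusing the objects defined above):
# sanity check (sketch): the "mean" reduction averages the squared error over
# all N * M entries, matching the batch loss defined above
reconstruction = model(X)
manual_loss = ((reconstruction - X.flatten(start_dim=1)) ** 2).mean()
assert allclose(loss_func(reconstruction, X.flatten(start_dim=1)), manual_loss)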
We will use BackPACK to compute the GGN diagonal of the mini-batch loss. To do that, we need to extend the model and the loss function.
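A minimal sketch of this step, using the extend function imported above (BackPACK requires both the model and the loss function to be extended):
model = extend(model)
loss_func = extend(loss_func)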
GGN diagonal in one backward pass
As usual, we can compute the GGN diagonal for the mini-batch loss in a single backward pass. The following function does that:
def diag_ggn_one_pass() -> List[Tensor]:
"""Compute the GGN diagonal in a single backward pass.
Returns:
GGN diagonal in parameter list format.
"""
reconstruction = model(X)
error = loss_func(reconstruction, X.flatten(start_dim=1))
with backpack(DiagGGNExact()):
error.backward()
return [p.diag_ggn_exact.clone() for p in model.parameters() if p.requires_grad]
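As a quick, optional illustration (a sketch, not part of the benchmark below): the result comes in parameter list format, i.e. one tensor per parameter with matching shape.
# each GGN diagonal entry has the same shape as its parameter
diag_ggn_example = diag_ggn_one_pass()
for param, diag in zip(model.parameters(), diag_ggn_example):
    assert diag.shape == param.shape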
Let’s run it and determine (i) its peak memory consumption and (ii) its run time.
print("GGN diagonal in one backward pass:")
start = time()
max_mem, diag_ggn = memory_usage(
diag_ggn_one_pass, interval=1e-3, max_usage=True, retval=True
)
end = time()
print(f"\tPeak memory [MiB]: {max_mem:.2e}")
print(f"\tTime [s]: {end-start:.2e}")
GGN diagonal in one backward pass:
Peak memory [MiB]: 2.49e+03
Time [s]: 4.00e+00
The memory consumption is pretty high, although our model is relatively small! If we make the model deeper, or increase the mini-batch size, we will quickly run out of memory.
This is because computing the GGN diagonal scales with the network’s output dimension. For classification settings like MNIST and CIFAR-10, this number is relatively small (10 classes). But for an auto-encoder, it is the input dimension \(M\), which in our case is:
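One way to print it (a one-line sketch using the input X defined above):
print(f"Output dimension: {X.shape[1:].numel()}")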
Output dimension: 3072
We will now take a look at how to circumvent the high peak memory by distributing the computation over multiple backward passes.
GGN diagonal in chunks
The GGN diagonal computation can be distributed across multiple backward passes. This greatly reduces peak memory consumption.
To see this, let’s consider the GGN diagonal for a single example \(\mathbf{x}\),
\(\mathrm{diag}\left( \left[\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})\right]^\top \frac{2}{M} \mathbf{I}_{M\times M} \left[\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})\right] \right)\,,\)
with the \(M \times |\mathbf{\theta}|\) Jacobian \(\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})\) of the model, and \(\frac{2}{M} \mathbf{I}_{M\times M}\) the mean squared error’s Hessian w.r.t. the reconstructed input. Here you can see that the memory consumption scales with the output dimension, as we need to compute \(M\) vector-Jacobian products.
Let \(S\), the chunk size, be a number that divides the output dimension \(M\). Then, we can decompose the above computation into \(M/S\) chunks:
\(\mathrm{diag}\left( \left[\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})\right]^\top \frac{2}{M} \mathbf{I}_{M\times M} \left[\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})\right] \right) = \frac{S}{M} \sum_{i=1}^{M/S} \mathrm{diag}\left( \left[\mathbf{J}_{\mathbf{\theta}} [f_{\mathbf{\theta}}(\mathbf{x})]_{(i-1)S+1:iS}\right]^\top \frac{2}{S} \mathbf{I}_{S\times S} \left[\mathbf{J}_{\mathbf{\theta}} [f_{\mathbf{\theta}}(\mathbf{x})]_{(i-1)S+1:iS}\right] \right)\,.\)
Each summand is, up to the scaling \(S/M\) (one over the number of chunks), the GGN diagonal of the mean squared error on a chunk, and its memory consumption scales with \(S < M\).
In summary, the computation split works as follows:
1. Compute \(f_{\mathbf{\theta}}(\mathbf{x})\) in a single forward pass.
2. Compute the reconstruction error for one chunk and its GGN diagonal in one backward pass.
3. Repeat the previous step for the remaining chunks, accumulating the GGN diagonals.
(This carries over to the mini-batch case in a straightforward fashion; we omit it here because the notation becomes more involved.)
Note that because we perform multiple backward passes, we need to tell PyTorch (and BackPACK) to retain the graph.
To slice out a chunk, we use BackPACK’s Slicing module.
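As a quick illustration of how we will use it (a minimal sketch, assuming Slicing simply applies the given index tuple to its input in the forward pass):
# slice out columns 0..3 while keeping all rows
keep_first_columns = Slicing((slice(None), slice(0, 4)))
dummy = rand(2, 10)
assert allclose(keep_first_columns(dummy), dummy[:, 0:4])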
Here is the implementation:
def diag_ggn_multiple_passes(num_chunks: int) -> List[Tensor]:
"""Compute the GGN diagonal in multiple backward passes.
Uses less memory than ``diag_ggn_one_pass`` if ``num_chunks > 1``.
Does the same as ``diag_ggn_one_pass`` for ``num_chunks = 1``.
Args:
num_chunks: Number of backward passes. Must divide the model's output dimension.
Returns:
GGN diagonal in parameter list format.
Raises:
ValueError:
If ``num_chunks`` does not divide the model's output dimension.
NotImplementedError:
If the model does not return a batched vector (the slicing logic is only
implemented for batched vectors, i.e. 2d tensors).
"""
reconstruction = model(X)
    if reconstruction.shape[1:].numel() % num_chunks != 0:
        raise ValueError("The model's output dimension must be divisible by the number of chunks.")
if reconstruction.dim() != 2:
raise NotImplementedError("Slicing logic only implemented for 2d outputs.")
chunk_size = reconstruction.shape[1:].numel() // num_chunks
diag_ggn_exact = [zeros_like(p) for p in model.parameters()]
for idx in range(num_chunks):
# set up the layer that extracts the current slice
slicing = (slice(None), slice(idx * chunk_size, (idx + 1) * chunk_size))
chunk_module = extend(Slicing(slicing))
# compute the chunk's loss
sliced_reconstruction = chunk_module(reconstruction)
sliced_X = X.flatten(start_dim=1)[slicing]
slice_error = loss_func(sliced_reconstruction, sliced_X)
# compute its GGN diagonal ...
with backpack(DiagGGNExact(), retain_graph=True):
slice_error.backward(retain_graph=True)
# ... and accumulate it
for p_idx, p in enumerate(model.parameters()):
diag_ggn_exact[p_idx] += p.diag_ggn_exact
# fix normalization
return [ggn / num_chunks for ggn in diag_ggn_exact]
Let’s benchmark peak memory and run time for different numbers of chunks:
num_chunks = [1, 4, 16, 64]
for n in num_chunks:
print(f"GGN diagonal in {n} backward passes:")
start = time()
max_mem, diag_ggn_chunk = memory_usage(
partial(diag_ggn_multiple_passes, n), interval=1e-3, max_usage=True, retval=True
)
end = time()
print(f"\tPeak memory [MiB]: {max_mem:.2e}")
print(f"\tTime [s]: {end-start:.2e}")
correct = [
allclose(diag1, diag2, rtol=5e-3, atol=5e-5)
for diag1, diag2 in zip(diag_ggn, diag_ggn_chunk)
]
print(f"\tCorrect: {correct}")
if not all(correct):
raise RuntimeError("Mismatch in GGN diagonals.")
GGN diagonal in 1 backward passes:
Peak memory [MiB]: 2.47e+03
Time [s]: 4.76e+00
Correct: [True, True, True, True]
GGN diagonal in 4 backward passes:
Peak memory [MiB]: 1.31e+03
Time [s]: 4.17e+00
Correct: [True, True, True, True]
GGN diagonal in 16 backward passes:
Peak memory [MiB]: 1.01e+03
Time [s]: 3.51e+00
Correct: [True, True, True, True]
GGN diagonal in 64 backward passes:
Peak memory [MiB]: 9.20e+02
Time [s]: 1.83e+00
Correct: [True, True, True, True]
We can see that using more chunks consistently decreases peak memory. Here, the run time decreases as well, but beyond some sweet spot, increasing the number of chunks further will eventually slow down the computation. The details of this trade-off depend on your model and compute architecture.
Concluding remarks
Here, we considered chunking the computation along the auto-encoder’s output dimension. There are other ways to achieve the desired effect of reducing peak memory:
In the mini-batch setting, we could consider only a subset of the mini-batch samples in each backpropagation. This can be done with the optional subsampling argument of many BackPACK extensions; see the mini-batch sub-sampling tutorial. This technique can be combined with the above.
We could also turn off the gradient computation (and thereby BackPACK’s computation) for all but a subgroup of parameters by setting their requires_grad attribute to False, and compute the GGN diagonal only for those parameters. However, for this to work we need to perform a new forward pass for each parameter subgroup; a rough sketch follows below.
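For illustration, here is a rough sketch of that second approach (hypothetical code, reusing diag_ggn_one_pass from above and assuming we only want the decoder’s GGN diagonal; parameters that do not require gradients are skipped):
# freeze the encoder so only the decoder's GGN diagonal is computed
for param in encoder.parameters():
    param.requires_grad = False

decoder_diag_ggn = diag_ggn_one_pass()  # fresh forward and backward pass
assert len(decoder_diag_ggn) == len(list(decoder.parameters()))

# restore the original setting
for param in encoder.parameters():
    param.requires_grad = True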
Total running time of the script: (0 minutes 18.282 seconds)