Extension hook example
The extension hook is a function called on each module after the BackPACK extensions have run. It can be used to reduce memory overhead if the goal is to compute transformations of BackPACK quantities: information can be compacted during the backward pass and obsolete tensors can be freed manually.
Here, we use it to compute the Hessian trace after each module and free the memory used to store the diagonal Hessian to reduce peak memory load.
Let’s start by loading some dummy data and extending the model.
import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
from backpack import backpack, extend
from backpack.extensions import DiagHessian
from backpack.utils.examples import load_one_batch_mnist
# make deterministic
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data
X, y = load_one_batch_mnist(batch_size=128)
X, y = X.to(device), y.to(device)
# model
model = Sequential(Flatten(), Linear(784, 10)).to(device)
lossfunc = CrossEntropyLoss().to(device)
model = extend(model)
lossfunc = extend(lossfunc)
Standard computation of the trace
loss = lossfunc(model(X), y)
with backpack(DiagHessian()):
    loss.backward()
tr_after_backward = sum(param.diag_h.sum() for param in model.parameters())
print(f"Tr(H) after backward: {tr_after_backward:.3f} ")
Tr(H) after backward: 690.637
Let’s clean up the computation graph and the existing BackPACK buffers.
del loss
for param in model.parameters():
    del param.diag_h
Extension hook
The extension hook is a function that takes a torch.nn.Module and returns None. It is executed on each module after the BackPACK extensions have run.
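For illustration, a minimal hook matching this signature (a sketch, not part of the original example) could simply report the module it was called on:

def verbose_hook(module):
    """Minimal extension hook: print the module BackPACK just finished processing."""
    print(f"BackPACK extensions ran on {module}")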
We use an object to store information from all modules. The hook will compute the trace of the Hessian for the block of parameters associated with the module and mark the tensors storing the diagonal Hessian to be freed.
class TraceHook:
    def __init__(self):
        """BackPACK extension hook that sums up the Hessian diagonal on the fly."""
        self.value = 0.0

    def sum_diag_h(self, module):
        """Add the module's diagonal Hessian elements to the ``value`` attribute."""
        for param in module.parameters():
            if hasattr(param, "diag_h"):
                self.value += param.diag_h.sum()
                delattr(param, "diag_h")
tr_hook = TraceHook()
Hook computation of the trace
loss = lossfunc(model(X), y)
with backpack(DiagHessian(), extension_hook=tr_hook.sum_diag_h):
    loss.backward()
tr_while_backward = tr_hook.value
print(f"Tr(H) while backward: {tr_while_backward:.3f}")
print(f"Same Tr(H)? {torch.allclose(tr_after_backward, tr_while_backward)}")
Tr(H) while backward: 690.637
Same Tr(H)? True
On memory usage
The delattr and del statements do not directly free memory, but mark the tensor for garbage collection by Python (as long as there are no other references to the tensor).
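Since the hook deletes each diag_h buffer as soon as it has been processed, one can verify after the backward pass that no parameter still carries it (a small sanity check, not part of the original example):

# after the hooked backward pass, no parameter should still hold a ``diag_h`` buffer
assert not any(hasattr(param, "diag_h") for param in model.parameters())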
For the diagonal Hessian, the memory savings are rather small, as it has the same size as the gradient. For quantities that scale with both batch and model size, like individual gradients, the extension hook might make it possible to fit the computation into memory where it would otherwise not be, as sketched below.
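As a hedged sketch of that use case, the same pattern could be combined with BackPACK's BatchGrad extension to accumulate per-sample squared gradient norms while freeing each grad_batch buffer as soon as its module has been processed. The hook class below is made up for this illustration; only BatchGrad and the grad_batch attribute come from BackPACK.

from backpack.extensions import BatchGrad


class GradNormHook:
    def __init__(self):
        """Accumulate per-sample squared gradient norms on the fly."""
        self.squared_norms = 0.0

    def __call__(self, module):
        """Add this module's contribution to the squared norms, then free the buffers."""
        for param in module.parameters():
            if hasattr(param, "grad_batch"):
                # ``grad_batch`` has shape ``[batch_size, *param.shape]``
                contribution = (param.grad_batch ** 2).flatten(start_dim=1).sum(dim=1)
                self.squared_norms += contribution
                delattr(param, "grad_batch")


grad_norm_hook = GradNormHook()

loss = lossfunc(model(X), y)
with backpack(BatchGrad(), extension_hook=grad_norm_hook):
    loss.backward()

print(f"First per-sample gradient norms: {grad_norm_hook.squared_norms.sqrt()[:3]}")

In this sketch, each grad_batch buffer (batch_size times the size of its parameter) only lives until its module's hook call, rather than until the end of the backward pass.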