Extension hook example
The extension hook is a function called on each module after the BackPACK extensions have run. It can reduce memory overhead when the goal is to compute transformations of BackPACK quantities: information can be compacted during the backward pass, and obsolete tensors can be freed manually.
Here, we use it to compute the Hessian trace after each module and free the memory used to store the diagonal Hessian to reduce peak memory load.
Let’s start by loading some dummy data and extending the model.
import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import DiagHessian
from backpack.utils.examples import load_one_batch_mnist

# make deterministic
torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# data
X, y = load_one_batch_mnist(batch_size=128)
X, y = X.to(device), y.to(device)

# model
model = Sequential(Flatten(), Linear(784, 10)).to(device)
lossfunc = CrossEntropyLoss().to(device)

model = extend(model)
lossfunc = extend(lossfunc)
Standard computation of the trace
Tr(H) after backward: 690.637
Let’s clean up the computation graph and existing BackPACK buffers.
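A minimal sketch of such a clean-up, using only PyTorch. The toy model is illustrative, and the diag_h attributes are attached by hand as stand-ins for what the extension would have stored.

```python
import torch
from torch.nn import Linear

# Illustrative toy model, not the one used in this example
model = Linear(4, 2)
loss = model(torch.rand(3, 4)).sum()

# Stand-in for the buffers a BackPACK extension would attach to each parameter
for param in model.parameters():
    param.diag_h = torch.zeros_like(param)

# Dropping the last reference to the loss lets the autograd graph be collected
del loss

# Remove the per-parameter buffers so that a second pass starts from a clean state
for param in model.parameters():
    if hasattr(param, "diag_h"):
        del param.diag_h
```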
The extension hook is a function that takes a torch.nn.Module (and returns None). It is executed on each module after the BackPACK extensions have run.
We use an object to store information from all modules. The hook will compute the trace of the Hessian for the block of parameters associated with the module and mark the tensors storing the diagonal Hessian to be freed.
class TraceHook:
    def __init__(self):
        """BackPACK extension hook that sums up the Hessian diagonal on the fly."""
        self.value = 0.0

    def sum_diag_h(self, module):
        """Sum ``value`` attribute with the diagonal Hessian elements."""
        for param in module.parameters():
            if hasattr(param, "diag_h"):
                self.value += param.diag_h.sum()
                delattr(param, "diag_h")


tr_hook = TraceHook()
Hook computation of the trace
Tr(H) while backward: 690.637
Same Tr(H)? True
On memory usage
The delattr and del calls do not free memory directly. They remove a reference to the tensor, so that Python can garbage-collect it (as long as there are no other references to the tensor).
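The effect of a remaining reference can be illustrated with plain Python. The sketch below uses a hypothetical Buffer class as a stand-in for a tensor and a weak reference to observe when the object goes away; the immediate release after the last del relies on CPython's reference counting.

```python
import weakref


class Buffer:
    """Hypothetical stand-in for a large tensor."""


big = Buffer()
alias = big                 # a second reference to the same object
probe = weakref.ref(big)    # observes the object without keeping it alive

del big
still_alive = probe() is not None   # the alias still references the object

del alias
freed = probe() is None             # no references left, so the object is gone

print(still_alive, freed)
```

In the hook, this means the diagonal is only released if nothing else (for example a list collecting the tensors) still points to it.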
For the diagonal Hessian, the memory savings are rather small, as it has the same size as the gradient. For quantities that scale with both batch and model size, such as individual gradients, the extension hook might make it possible to fit a computation in RAM that would not fit otherwise.
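A back-of-the-envelope calculation for the model in this example makes the difference concrete. It assumes 32-bit floats (4 bytes per entry, PyTorch's default dtype) and counts only the stored quantities themselves.

```python
batch_size = 128
num_params = 784 * 10 + 10      # weights and biases of Linear(784, 10)
bytes_per_float = 4             # assumption: float32

# Diagonal Hessian: one entry per parameter, same size as the gradient
diag_h_bytes = num_params * bytes_per_float

# Individual gradients: one gradient per sample in the batch
indiv_grad_bytes = batch_size * num_params * bytes_per_float

print(num_params)         # 7850
print(diag_h_bytes)       # 31400
print(indiv_grad_bytes)   # 4019200
```

The individual gradients take 128 times as much memory as the diagonal, and the gap grows with both batch size and model size.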