Hutchinson Trace Estimation

This example illustrates how to estimate the Hessian trace of a neural network using Hutchinson’s method [Hutchinson, 1990], an algorithm that obtains such an estimate from matrix-vector products alone:

$\text{Let } A \in \mathbb{R}^{D \times D} \text{ and let } v \in \mathbb{R}^D \text{ be a random vector such that } \mathbb{E}[vv^T] = I. \text{ Then,}$
$\mathrm{Tr}(A) = \mathbb{E}[v^TAv] \approx \frac{1}{V}\sum_{i=1}^{V}v_i^TAv_i,$
$\text{where } v_1, \dots, v_V \text{ are independent samples of } v.$
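
The identity behind the estimator follows in a few lines from the cyclic property of the trace. The derivation below is a sketch that uses only the assumption stated above, namely that the second moment of v is the identity:

```latex
\begin{aligned}
\mathbb{E}[v^T A v]
&= \mathbb{E}[\mathrm{Tr}(v^T A v)] && \text{(a scalar equals its own trace)} \\
&= \mathbb{E}[\mathrm{Tr}(A v v^T)] && \text{(cyclic property of the trace)} \\
&= \mathrm{Tr}(A \, \mathbb{E}[v v^T]) && \text{(linearity of trace and expectation)} \\
&= \mathrm{Tr}(A I) = \mathrm{Tr}(A).
\end{aligned}
```

For instance, a vector whose entries are independently +1 or -1 with equal probability satisfies the assumption, because its entries have zero mean, unit second moment, and are uncorrelated.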

We will draw v from a Rademacher distribution (each entry is +1 or -1 with equal probability) and use Hessian-free multiplication. This can be done with plain autodiff. Note, however, that the sampled vectors do not depend on each other, so the Hessian-vector products (HVPs) could in principle be performed in parallel. We can use BackPACK’s HMP (Hessian-matrix product) extension to do so and investigate the potential speedup.
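
To make the estimator concrete before involving a neural network, here is a minimal, dependency-free sketch on a small explicit matrix. The names `sample_rademacher`, `matvec`, and `hutchinson_trace`, as well as the matrix `A`, are made up for this illustration and are not part of BackPACK. The estimator touches `A` only through matrix-vector products.

```python
import random


def sample_rademacher(n, rng):
    """Return n independent entries, each +1 or -1 with probability 1/2."""
    return [rng.choice((-1.0, 1.0)) for _ in range(n)]


def matvec(A, v):
    """Multiply the square matrix A (a list of rows) with the vector v."""
    return [sum(a_ij * v_j for a_ij, v_j in zip(row, v)) for row in A]


def hutchinson_trace(A, num_samples, rng):
    """Average v^T A v over num_samples Rademacher vectors."""
    total = 0.0
    for _ in range(num_samples):
        v = sample_rademacher(len(A), rng)
        Av = matvec(A, v)  # the only access to A is a matrix-vector product
        total += sum(v_i * Av_i for v_i, Av_i in zip(v, Av))
    return total / num_samples


# Hypothetical 3x3 symmetric test matrix with trace 2 + 3 + 4 = 9
A = [[2.0, 1.0, 0.5], [1.0, 3.0, -1.0], [0.5, -1.0, 4.0]]
exact = sum(A[i][i] for i in range(len(A)))
estimate = hutchinson_trace(A, num_samples=20000, rng=random.Random(0))
print("exact: {:.3f}, estimate: {:.3f}".format(exact, estimate))
```

For a diagonal matrix, every single Rademacher sample already returns the exact trace, since each entry of v squares to one. The noise in the estimate comes entirely from the off-diagonal entries.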

Let’s get the imports and define how to sample from a Rademacher distribution.

import time

import matplotlib.pyplot as plt
import torch

from backpack import backpack, extend
from backpack.extensions import HMP, DiagHessian
from backpack.hessianfree.hvp import hessian_vector_product
from backpack.utils.examples import load_one_batch_mnist

BATCH_SIZE = 256
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)

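def rademacher(shape, dtype=torch.float32, device=DEVICE):
    """Sample from Rademacher distribution."""
    # entries are +1 or -1, each with probability 1/2
    rand = ((torch.rand(shape) < 0.5)) * 2 - 1
    return rand.to(dtype).to(device)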


Creating the model and loss

We will use a small NN with 2 linear layers without bias (for a bias of size d, the exact trace can be obtained in d HVPs).

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 20, bias=False),
    torch.nn.Sigmoid(),
    torch.nn.Linear(20, 10, bias=False),
).to(DEVICE)
model = extend(model)

loss_function = torch.nn.CrossEntropyLoss().to(DEVICE)
loss_function = extend(loss_function)
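
The remark about biases rests on a general fact: the trace of any d x d matrix can be read off exactly with d matrix-vector products, one per unit vector. The following sketch illustrates this on a small explicit matrix. The helper names `exact_trace_from_matvec` and `multiply_with_A` and the matrix `A` are made up for this illustration.

```python
def exact_trace_from_matvec(matvec, dim):
    """Sum e_i^T A e_i over all unit vectors e_i, using dim products."""
    trace = 0.0
    for i in range(dim):
        e_i = [1.0 if j == i else 0.0 for j in range(dim)]
        Ae_i = matvec(e_i)  # i-th column of A
        trace += Ae_i[i]  # e_i^T A e_i = A[i][i]
    return trace


# Hypothetical 2x2 matrix, accessed only through multiplication
A = [[2.0, 1.0], [1.0, 3.0]]


def multiply_with_A(v):
    """Matrix-vector product with A."""
    return [sum(a * x for a, x in zip(row, v)) for row in A]


print(exact_trace_from_matvec(multiply_with_A, dim=2))  # prints 5.0
```

This is only practical when d is small. For the weight matrices of the model above, d is in the thousands, which is where a stochastic estimate with far fewer products becomes attractive.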


In the following, we load a batch from MNIST, compute the loss, and trigger the backward pass inside a with backpack(...) context so that we have access to the extensions we are going to use (DiagHessian and HMP).

x, y = load_one_batch_mnist(BATCH_SIZE)
x, y = x.to(DEVICE), y.to(DEVICE)

def forward_backward_with_backpack():
    """Provide working access to BackPACK's DiagHessian and HMP."""
    loss = loss_function(model(x), y)

    with backpack(DiagHessian(), HMP()):
        # keep graph for autodiff HVPs
        loss.backward(retain_graph=True)

    return loss

loss = forward_backward_with_backpack()


Exact trace computation

To make sure our implementation is fine, and to develop a feeling for the Hutchinson estimator quality, let’s compute the exact trace by summing up the Hessian diagonal.

def exact_trace():
    """Exact trace from sum of Hessian diagonal."""
    param_trace = [p.diag_h.sum().item() for p in model.parameters()]
    return sum(param_trace)

print("Exact trace: {:.3f}".format(exact_trace()))

Exact trace: 12.455


Trace estimation (BackPACK’s HMP)

BackPACK’s HMP extension gives access to multiplication with the parameter Hessian, which is one diagonal block in the full Hessian whose trace we want to estimate. The multiplication can even handle multiple vectors at a time. Here is the implementation. The computation of V HVPs, which might exceed our available memory, is chunked into batches of size V_batch.

def hutchinson_trace_hmp(V, V_batch=1):
    """Hessian trace estimate using BackPACK's HMP extension.

    Perform V_batch Hessian multiplications at a time.
    """
    V_count = 0
    trace = 0

    while V_count < V:
        V_missing = V - V_count
        V_next = min(V_batch, V_missing)

        for param in model.parameters():
            # stack of V_next Rademacher vectors, shape [V_next, *param.shape]
            v = rademacher((V_next, *param.shape))
            Hv = param.hmp(v).detach()
            vHv = torch.einsum("i,i->", v.flatten(), Hv.flatten().detach())
            trace += vHv / V

        V_count += V_next

    return trace

print(
    "Trace estimate via BackPACK's HMP extension: {:.3f}".format(
        hutchinson_trace_hmp(V=1000, V_batch=10)
    )
)

Trace estimate via BackPACK's HMP extension: 12.308
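
The chunking logic in hutchinson_trace_hmp can be checked in isolation. The sketch below reproduces only the bookkeeping of the while loop. The generator name `chunk_sizes` is made up for this illustration, and no Hessian multiplication is performed.

```python
def chunk_sizes(V, V_batch):
    """Yield the number of vectors processed in each chunk."""
    V_count = 0
    while V_count < V:
        V_next = min(V_batch, V - V_count)  # the last chunk may be smaller
        yield V_next
        V_count += V_next


print(list(chunk_sizes(25, 10)))  # prints [10, 10, 5]
```

The chunk sizes always add up to V, so dividing every contribution by V, as done in the estimator, gives the average over all V samples regardless of V_batch.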


Trace estimation (autodiff, full Hessian)

We can also use autodiff tricks to compute a single HVP at a time, provided by the utility function hessian_vector_product in BackPACK. Here is the implementation, and a test:

def hutchinson_trace_autodiff(V):
    """Hessian trace estimate using autodiff HVPs."""
    trace = 0

    for _ in range(V):
        vec = [rademacher(p.shape) for p in model.parameters()]
        Hvec = hessian_vector_product(loss, list(model.parameters()), vec)

        for v, Hv in zip(vec, Hvec):
            vHv = torch.einsum("i,i->", v.flatten(), Hv.flatten())
            trace += vHv / V

    return trace

print(
    "Trace estimate via PyTorch's HVP: {:.3f}".format(hutchinson_trace_autodiff(V=1000))
)

Trace estimate via PyTorch's HVP: 12.142


Trace estimation (autodiff, block-diagonal Hessian)

Since HMP uses only the Hessian block diagonal and not the full Hessian, here is the corresponding autodiff implementation, which uses the same matrix as HMP. We will revisit it in the runtime comparison below.

def hutchinson_trace_autodiff_blockwise(V):
    """Hessian trace estimate using autodiff block HVPs."""
    trace = 0

    for _ in range(V):
        for p in model.parameters():
            # one Rademacher vector for this parameter's diagonal block
            v = [rademacher(p.shape)]
            Hv = hessian_vector_product(loss, [p], v)
            vHv = torch.einsum("i,i->", v[0].flatten(), Hv[0].flatten())

            trace += vHv / V

    return trace

print(
    "Trace estimate via PyTorch's blockwise HVP: {:.3f}".format(
        hutchinson_trace_autodiff_blockwise(V=1000)
    )
)

# restore BackPACK IO, which is deleted by autodiff HVP
loss = forward_backward_with_backpack()

Trace estimate via PyTorch's blockwise HVP: 12.581
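
The full and the block-diagonal estimators target the same quantity. A short sketch of why, where the notation is introduced here for illustration: v is split into one part v_b per parameter, and H_{bc} denotes the corresponding block of the Hessian H.

```latex
\begin{aligned}
v^T H v
&= \sum_{b} v_b^T H_{bb} v_b + \sum_{b \neq c} v_b^T H_{bc} v_c, \\
\mathbb{E}[v_b^T H_{bc} v_c]
&= \mathrm{Tr}(H_{bc} \, \mathbb{E}[v_c v_b^T]) = 0
\quad \text{for } b \neq c, \text{ since } \mathbb{E}[vv^T] = I, \\
\mathbb{E}[v^T H v]
&= \sum_{b} \mathbb{E}[v_b^T H_{bb} v_b]
= \sum_{b} \mathrm{Tr}(H_{bb}) = \mathrm{Tr}(H).
\end{aligned}
```

The cross terms between different parameters have zero mean, so they only add noise to the full-Hessian estimate. The blockwise estimators omit them from the start.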


Trace approximation accuracy

Next, let’s observe how the approximation improves with the number of samples. Here, we plot multiple runs of the Hutchinson trace estimate, initialized at different random seeds.

V_steps = 30
V_list = torch.logspace(1, 3, steps=V_steps).int()
V_batch = 10
num_curves = 15

fig = plt.figure(figsize=(20, 10))
plt.xlabel("Number of Samples")
plt.ylabel("Trace")
plt.semilogx(V_list, V_steps * [exact_trace()], color="blue", label="Exact")

for i in range(num_curves):
    trace_estimates = []

    for V in V_list:
        torch.manual_seed(i)
        trace_estimates.append(hutchinson_trace_hmp(V, V_batch))

    plt.semilogx(
        V_list,
        [trace_estimate.cpu() for trace_estimate in trace_estimates],
        linestyle="--",
        color="orange",
        label="Hutchinson" if i == 0 else None,
    )

_ = plt.legend()
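
Because the estimate is an average of V independent samples, its standard deviation is expected to shrink proportionally to 1/sqrt(V). The dependency-free sketch below measures this on a small explicit matrix rather than on the network Hessian. The function names, the matrix `A`, and the sample counts are chosen for this illustration only.

```python
import random
import statistics


def hutchinson_trace(A, num_samples, rng):
    """Average v^T A v over num_samples Rademacher vectors."""
    dim = len(A)
    total = 0.0
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        total += sum(v[i] * A[i][j] * v[j] for i in range(dim) for j in range(dim))
    return total / num_samples


def spread(A, num_samples, num_runs, rng):
    """Empirical standard deviation of the estimate over independent runs."""
    estimates = [hutchinson_trace(A, num_samples, rng) for _ in range(num_runs)]
    return statistics.stdev(estimates)


# Hypothetical 3x3 symmetric test matrix
A = [[2.0, 1.0, 0.5], [1.0, 3.0, -1.0], [0.5, -1.0, 4.0]]
rng = random.Random(0)
for V in (25, 100, 400):
    print("V = {:3d}, std of estimate: {:.3f}".format(V, spread(A, V, 200, rng)))
```

Each fourfold increase in V should roughly halve the reported spread, which matches the narrowing of the curves in the plot as the number of samples grows.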


Runtime comparison

Finally, we investigate whether the trace estimation is sped up by vectorizing the HVPs. In particular, let’s compare the estimates obtained with autodiff HVPs (no parallelization), autodiff block-diagonal HVPs (no parallelization), and block-diagonal vectorized HVPs (HMP).

V = 1000

def time_hutchinson_trace_autodiff(V):
    start = time.time()
    trace = hutchinson_trace_autodiff(V)
    end = time.time()
    duration = end - start
    print(
        "Estim. trace: {:.3f}, time {:.3f}, (autodiff, full HVP)".format(
            trace, duration
        )
    )

def time_hutchinson_trace_autodiff_blockwise(V):
    start = time.time()
    trace = hutchinson_trace_autodiff_blockwise(V)
    end = time.time()
    duration = end - start
    print(
        "Estim. trace: {:.3f}, time {:.3f}, (autodiff, block HVP)".format(
            trace, duration
        )
    )

def time_hutchinson_trace_hmp(V, V_batch):
    start = time.time()
    trace = hutchinson_trace_hmp(V, V_batch)
    end = time.time()
    duration = end - start
    print(
        "Estim. trace: {:.3f}, time {:.3f}, (BackPACK, V_batch={} block HVP)".format(
            trace, duration, V_batch
        )
    )

print("Exact trace:  {:.3f}".format(exact_trace()))
time_hutchinson_trace_autodiff(V)
time_hutchinson_trace_autodiff_blockwise(V)
# restore BackPACK IO, which is deleted by autodiff HVP
loss = forward_backward_with_backpack()
time_hutchinson_trace_hmp(V, V_batch=5)
time_hutchinson_trace_hmp(V, V_batch=10)
time_hutchinson_trace_hmp(V, V_batch=20)

Exact trace:  12.455
Estim. trace: 12.373, time 2.778, (autodiff, full HVP)
Estim. trace: 12.095, time 3.320, (autodiff, block HVP)
Estim. trace: 12.465, time 1.133, (BackPACK, V_batch=5 block HVP)
Estim. trace: 12.298, time 0.747, (BackPACK, V_batch=10 block HVP)
Estim. trace: 12.345, time 0.555, (BackPACK, V_batch=20 block HVP)


Looks like the parallel Hessian-vector products are able to speed up the computation. Nice.
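
The three timing functions above repeat the same start/stop pattern. One possible way to factor it out is sketched below. The helper `timed` and its `sync` argument are made up for this illustration and are not part of BackPACK. The sketch uses `time.perf_counter` from the standard library, which is intended for measuring short durations. Because CUDA kernels are launched asynchronously, a GPU measurement would typically pass `torch.cuda.synchronize` as `sync`; that is an assumption about how one might use the helper, not something the script above does.

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Run fn(*args, **kwargs) and return (result, seconds elapsed).

    sync is an optional zero-argument callable that is invoked right before
    each clock reading, for example to wait for pending GPU work.
    """
    if sync is not None:
        sync()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()
    return result, time.perf_counter() - start


# Stand-in workload; in the script above this would be an estimator call
result, seconds = timed(sum, range(1_000_000))
print("result: {}, time: {:.4f} s".format(result, seconds))
```

With such a helper, a call like `timed(hutchinson_trace_hmp, V, V_batch)` would return the estimate together with its runtime, leaving the formatting of the report to the caller.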

Note that instead of the Hessian, we could have also used other interesting matrices, such as the generalized Gauss-Newton. BackPACK also offers a vectorized multiplication with the latter’s block-diagonal (see the GGNMP extension).
