# Example using all extensions

Basic example showing how to compute the gradient and other quantities with BackPACK on a linear model for MNIST.

Let’s start by loading a batch of MNIST data and extending the model and the loss function.

```
from torch import rand
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
from backpack import backpack, extend
from backpack.extensions import (
    GGNMP,
    HMP,
    KFAC,
    KFLR,
    KFRA,
    PCHMP,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
    SumGradSquared,
    Variance,
)
from backpack.utils.examples import load_one_batch_mnist
X, y = load_one_batch_mnist(batch_size=512)
model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()
model = extend(model)
lossfunc = extend(lossfunc)
```

```
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100.0%
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100.0%
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100.0%
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100.0%
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
```

## First-order extensions

Batch gradients

```
loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.grad_batch.shape: torch.Size([512, 10, 784])
1.bias
.grad.shape: torch.Size([10])
.grad_batch.shape: torch.Size([512, 10])
```
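BackPACK's documentation states that, for a mean-reduced loss such as `CrossEntropyLoss()`, `.grad_batch` holds the individual gradients already scaled by 1/N, so they add up to `.grad`. The following pure-PyTorch sketch reproduces that convention with a per-sample loop instead of BackPACK; the small batch size `N = 8`, the random data, and the names `inputs`, `labels`, `net`, and `loss_fn` are illustrative choices (distinct from the example's own variables so they are not overwritten):

```python
import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

torch.manual_seed(0)
N = 8  # small batch (illustrative) so the per-sample loop stays cheap
inputs = torch.rand(N, 1, 28, 28)  # random stand-in for MNIST images
labels = torch.randint(0, 10, (N,))
net = Sequential(Flatten(), Linear(784, 10))
loss_fn = CrossEntropyLoss()  # reduction="mean", as in the example above

params = list(net.parameters())
full_grads = torch.autograd.grad(loss_fn(net(inputs), labels), params)

# Per-sample gradients of loss_n / N: for a mean-reduced loss this is the
# scaling BackPACK documents for .grad_batch
grad_batch = [torch.zeros(N, *p.shape) for p in params]
for n in range(N):
    loss_n = loss_fn(net(inputs[n : n + 1]), labels[n : n + 1]) / N
    for gb, g in zip(grad_batch, torch.autograd.grad(loss_n, params)):
        gb[n] = g

for gb, g in zip(grad_batch, full_grads):
    # Shapes [N, 10, 784] and [N, 10]; the individual gradients add up to the full gradient
    print(tuple(gb.shape), torch.allclose(gb.sum(0), g, atol=1e-6))
```

A loop like this is slow for large batches, but it is a convenient reference for checking `.grad_batch` on a toy model.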

Variance

```
loss = lossfunc(model(X), y)
with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".variance.shape: ", param.variance.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.variance.shape: torch.Size([10, 784])
1.bias
.grad.shape: torch.Size([10])
.variance.shape: torch.Size([10])
```
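According to BackPACK's documentation, `.variance` is the variance of the individual gradients as they are stored in `.grad_batch` (so, for a mean-reduced loss, of the 1/N-scaled gradients). It can be written in terms of the sum of the individual gradients and the sum of their squares. This sketch uses a random stand-in for the bias's `.grad_batch` and assumes the population (divide-by-N) normalization, which you can confirm by comparing against `param.variance`:

```python
import torch

torch.manual_seed(0)
N = 512
grad_batch = torch.randn(N, 10) / N  # stand-in for the bias's .grad_batch (1/N-scaled)

grad = grad_batch.sum(0)                   # what .grad holds
sum_grad_squared = (grad_batch**2).sum(0)  # what .sum_grad_squared holds

# Variance from first and second moments, E[g^2] - E[g]^2, averaging over the batch
variance = sum_grad_squared / N - (grad / N) ** 2

# The same quantity computed directly from the centered individual gradients
direct = ((grad_batch - grad_batch.mean(0)) ** 2).mean(0)
print(variance.shape, torch.allclose(variance, direct, rtol=1e-4))
```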

Second moment/sum of gradients squared

```
loss = lossfunc(model(X), y)
with backpack(SumGradSquared()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.sum_grad_squared.shape: torch.Size([10, 784])
1.bias
.grad.shape: torch.Size([10])
.sum_grad_squared.shape: torch.Size([10])
```
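`.sum_grad_squared` is the sum over the batch of the element-wise squares of the individual gradients in `.grad_batch`. Because those carry a 1/N factor for a mean-reduced loss, it equals the empirical second moment of the unscaled per-sample gradients divided by N. A small sketch of this bookkeeping, with a random tensor standing in for the unscaled per-sample gradients of the bias:

```python
import torch

torch.manual_seed(0)
N = 512
g_unscaled = torch.randn(N, 10)  # stand-in: gradient of each sample's own loss
grad_batch = g_unscaled / N      # scaling for a mean-reduced loss, as in .grad_batch

sum_grad_squared = (grad_batch**2).sum(0)  # the quantity SumGradSquared stores
second_moment = (g_unscaled**2).mean(0)    # empirical second moment of unscaled gradients

# Rescaling by N turns one into the other
print(torch.allclose(N * sum_grad_squared, second_moment))
```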

Squared L2 norm of individual gradients

```
loss = lossfunc(model(X), y)
with backpack(BatchL2Grad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".batch_l2.shape: ", param.batch_l2.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.batch_l2.shape: torch.Size([512])
1.bias
.grad.shape: torch.Size([10])
.batch_l2.shape: torch.Size([512])
```
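BackPACK documents `.batch_l2` as the squared L2 norm of each individual gradient, computed separately for every parameter tensor, hence one number per sample. For a linear layer this does not require forming `.grad_batch`: sample n's weight gradient is the outer product of the gradient g_n with respect to the layer output and the layer input x_n, so its squared Frobenius norm is ||g_n||^2 * ||x_n||^2. The sketch checks that identity with random stand-ins for both vectors; whether BackPACK uses exactly this shortcut internally is an assumption:

```python
import torch

torch.manual_seed(0)
N, d_in, d_out = 512, 784, 10
x = torch.rand(N, d_in)    # layer inputs (flattened images)
g = torch.randn(N, d_out)  # stand-in: gradients w.r.t. the layer outputs

# Individual weight gradients are outer products g_n x_n^T, shape [N, 10, 784]
grad_batch = torch.einsum("no,ni->noi", g, x)

batch_l2_naive = (grad_batch**2).flatten(start_dim=1).sum(1)  # materializes [N, 10, 784]
batch_l2_cheap = (g**2).sum(1) * (x**2).sum(1)                # never forms grad_batch
print(batch_l2_cheap.shape, torch.allclose(batch_l2_naive, batch_l2_cheap, rtol=1e-4))
```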

It’s also possible to request multiple quantities in a single backward pass.

```
loss = lossfunc(model(X), y)
with backpack(BatchGrad(), Variance(), SumGradSquared(), BatchL2Grad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)
    print(".variance.shape: ", param.variance.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print(".batch_l2.shape: ", param.batch_l2.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.grad_batch.shape: torch.Size([512, 10, 784])
.variance.shape: torch.Size([10, 784])
.sum_grad_squared.shape: torch.Size([10, 784])
.batch_l2.shape: torch.Size([512])
1.bias
.grad.shape: torch.Size([10])
.grad_batch.shape: torch.Size([512, 10])
.variance.shape: torch.Size([10])
.sum_grad_squared.shape: torch.Size([10])
.batch_l2.shape: torch.Size([512])
```

## Second-order extensions

Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation

```
loss = lossfunc(model(X), y)
with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.diag_ggn_mc.shape: torch.Size([10, 784])
.diag_ggn_exact.shape: torch.Size([10, 784])
1.bias
.grad.shape: torch.Size([10])
.diag_ggn_mc.shape: torch.Size([10])
.diag_ggn_exact.shape: torch.Size([10])
```
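For this model the GGN diagonal has a closed form. The logits are linear in the parameters, so the GGN coincides with the Hessian, and for the bias (whose Jacobian is the identity) the diagonal is the batch average of `p * (1 - p)`, with `p` the softmax probabilities. A Monte-Carlo estimate replaces the exact loss Hessian `diag(p) - p p^T` by averaged outer products of gradients computed with labels sampled from `p`. The pure-PyTorch sketch below illustrates both for the bias with random data and the standard sampling construction; BackPACK's internal implementation may differ in detail, and the large `mc_samples` is chosen only to show convergence:

```python
import torch
from torch.nn.functional import cross_entropy, one_hot, softmax

torch.manual_seed(0)
N, d_in, C = 64, 784, 10
inputs = torch.rand(N, d_in)
labels = torch.randint(0, C, (N,))
W, b = torch.randn(C, d_in) * 0.01, torch.zeros(C)

p = softmax(inputs @ W.T + b, dim=1)  # [N, C] predicted class probabilities

# Exact: diagonal of (1/N) * sum_n (diag(p_n) - p_n p_n^T)
diag_ggn_exact = (p * (1 - p)).mean(0)

# Monte-Carlo: sample labels from p, average squared gradients w.r.t. the logits
mc_samples = 1000
y_sampled = torch.multinomial(p, mc_samples, replacement=True)  # [N, S]
g = p[:, None, :] - one_hot(y_sampled, num_classes=C).float()   # [N, S, C]
diag_ggn_mc = (g**2).mean(dim=(0, 1))

# The model is linear in b, so the exact GGN equals the Hessian from autograd
H = torch.autograd.functional.hessian(lambda bias: cross_entropy(inputs @ W.T + bias, labels), b)
print(torch.allclose(diag_ggn_exact, H.diagonal(), atol=1e-6))
print((diag_ggn_mc - diag_ggn_exact).abs().max())  # shrinks as mc_samples grows
```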

Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation

```
loss = lossfunc(model(X), y)
with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".diag_ggn_mc_batch.shape: ", param.diag_ggn_mc_batch.shape)
    print(".diag_ggn_exact_batch.shape: ", param.diag_ggn_exact_batch.shape)
```

```
1.weight
.diag_ggn_mc_batch.shape: torch.Size([512, 10, 784])
.diag_ggn_exact_batch.shape: torch.Size([512, 10, 784])
1.bias
.diag_ggn_mc_batch.shape: torch.Size([512, 10])
.diag_ggn_exact_batch.shape: torch.Size([512, 10])
```
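The GGN of a mean-reduced loss is a sum of per-sample terms, each carrying a 1/N factor, so summing the per-sample diagonals over the batch recovers the full diagonal. For the weight of a linear layer, sample n contributes the outer product of `p_n * (1 - p_n) / N` (one entry per class) with the squared input `x_n ** 2`. A pure-PyTorch sketch that builds these per-sample diagonals and checks their sum against the exact Hessian diagonal; the small input dimension is an illustrative choice that keeps the full Hessian in memory:

```python
import torch
from torch.nn.functional import cross_entropy, softmax

torch.manual_seed(0)
N, d_in, C = 32, 20, 10  # small d_in (illustrative) so the full Hessian fits in memory
inputs = torch.rand(N, d_in)
labels = torch.randint(0, C, (N,))
W, b = torch.randn(C, d_in), torch.randn(C)

p = softmax(inputs @ W.T + b, dim=1)
h = p * (1 - p) / N  # [N, C] diagonal of each sample's loss Hessian w.r.t. the logits, incl. 1/N

# Per-sample GGN diagonals: [N, C, d_in] for the weight and [N, C] for the bias
diag_ggn_batch_weight = h[:, :, None] * inputs[:, None, :] ** 2
diag_ggn_batch_bias = h

# Summed over the batch they give the full diagonal; compare with the exact Hessian,
# which equals the GGN here because the logits are linear in W
H = torch.autograd.functional.hessian(lambda w: cross_entropy(inputs @ w.T + b, labels), W)
diag_H = H.reshape(C * d_in, C * d_in).diagonal().reshape(C, d_in)
print(diag_ggn_batch_weight.shape, diag_ggn_batch_bias.shape)
print(torch.allclose(diag_ggn_batch_weight.sum(0), diag_H, atol=1e-6))
```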

KFAC, KFRA and KFLR

```
loss = lossfunc(model(X), y)
with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac])
    print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr])
    print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra])
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.kfac (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
.kflr (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
.kfra (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
1.bias
.grad.shape: torch.Size([10])
.kfac (shapes): [torch.Size([10, 10])]
.kflr (shapes): [torch.Size([10, 10])]
.kfra (shapes): [torch.Size([10, 10])]
```
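KFAC, KFLR, and KFRA approximate the GGN block of a weight by a Kronecker product of two small factors, one over the layer's outputs (10 × 10 here) and one over its inputs (784 × 784), which is why each weight gets a list of two matrices while the bias gets only the output factor. The pure-PyTorch sketch below builds KFLR-style factors for a linear layer (the batch average of the exact loss Hessians with respect to the logits, and the batch average of input outer products) and compares their Kronecker product with the exact GGN block. Where BackPACK places the 1/N factors, and the small sizes used here, are assumptions for illustration:

```python
import torch
from torch.nn.functional import softmax


def kron_factors_and_exact(inputs, W, b):
    """Return KFLR-style factors (G, A) and the exact GGN block of the weight."""
    N = inputs.shape[0]
    p = softmax(inputs @ W.T + b, dim=1)
    H = torch.diag_embed(p) - p[:, :, None] * p[:, None, :]   # [N, C, C] loss Hessians w.r.t. logits
    G = H.mean(0)                                             # output-side factor, C x C
    A = (inputs[:, :, None] * inputs[:, None, :]).mean(0)     # input-side factor, d_in x d_in
    # Exact GGN of the mean loss for row-major vec(W): (1/N) sum_n kron(H_n, x_n x_n^T)
    exact = sum(torch.kron(H[n], torch.outer(inputs[n], inputs[n])) for n in range(N)) / N
    return G, A, exact


torch.manual_seed(0)
C, d_in = 10, 6
W, b = torch.randn(C, d_in), torch.randn(C)

G, A, exact = kron_factors_and_exact(torch.rand(16, d_in), W, b)
rel_err = (torch.kron(G, A) - exact).norm() / exact.norm()
print(G.shape, A.shape, rel_err)  # rel_err is generally non-zero: KF is an approximation

# If every sample has the same input, the factorization is exact
x_same = torch.rand(1, d_in).repeat(16, 1)
G_s, A_s, exact_s = kron_factors_and_exact(x_same, W, b)
print(torch.allclose(torch.kron(G_s, A_s), exact_s, atol=1e-6))
```

Storing the two factors needs C² + d_in² numbers instead of (C · d_in)² for the full block, which is what makes this family of approximations practical for wide layers.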

Diagonal Hessian and per-sample diagonal Hessian

```
1.weight
.grad.shape: torch.Size([10, 784])
.diag_h.shape: torch.Size([10, 784])
.diag_h_batch.shape: torch.Size([512, 10, 784])
1.bias
.grad.shape: torch.Size([10])
.diag_h.shape: torch.Size([10])
.diag_h_batch.shape: torch.Size([512, 10])
```

Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation

```
loss = lossfunc(model(X), y)
with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape)
    print(".sqrt_ggn_mc.shape: ", param.sqrt_ggn_mc.shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
.sqrt_ggn_exact.shape: torch.Size([10, 512, 10, 784])
.sqrt_ggn_mc.shape: torch.Size([1, 512, 10, 784])
1.bias
.grad.shape: torch.Size([10])
.sqrt_ggn_exact.shape: torch.Size([10, 512, 10])
.sqrt_ggn_mc.shape: torch.Size([1, 512, 10])
```
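The leading dimension counts the columns of the square-root factor contributed by each sample: one per class (10) for the exact GGN, one per sampled label (1) for the Monte-Carlo version. Contracting the factor with itself over the two leading dimensions gives back the GGN block. The sketch below shows this for the bias with one valid factorization of the cross-entropy Hessian, `diag(p) - p p^T = S S^T` with `S = diag(sqrt(p)) - p sqrt(p)^T`; BackPACK may use a different but equivalent factor, and the data are random:

```python
import torch
from torch.nn.functional import softmax

torch.manual_seed(0)
N, C, d_in = 512, 10, 784
inputs = torch.rand(N, d_in)
W, b = torch.randn(C, d_in) * 0.01, torch.zeros(C)
p = softmax(inputs @ W.T + b, dim=1)  # [N, C]

# One valid factorization of each sample's loss Hessian: diag(p) - p p^T = S S^T with
# S = diag(sqrt(p)) - p sqrt(p)^T; dividing by sqrt(N) accounts for the mean reduction
sqrt_p = p.sqrt()
S = (torch.diag_embed(sqrt_p) - p[:, :, None] * sqrt_p[:, None, :]) / N**0.5  # [N, C, C]

# The bias Jacobian is the identity, so column c of S_n is directly a square-root column;
# arrange as [C, N, C] to mirror the layout of .sqrt_ggn_exact for the bias
sqrt_ggn_bias = S.permute(2, 0, 1)

ggn_from_sqrt = torch.einsum("cni,cnj->ij", sqrt_ggn_bias, sqrt_ggn_bias)
ggn_direct = (torch.diag_embed(p) - p[:, :, None] * p[:, None, :]).mean(0)
print(sqrt_ggn_bias.shape, torch.allclose(ggn_from_sqrt, ggn_direct, atol=1e-6))
```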

## Block-diagonal curvature products

Curvature-matrix product (`MP`) extensions provide functions that multiply with the block diagonal of different curvature matrices, such as

- the Hessian (`HMP`)
- the generalized Gauss-Newton (`GGNMP`)
- the positive-curvature Hessian (`PCHMP`)

Multiply a random vector with curvature blocks.

```
V = 1
for name, param in model.named_parameters():
    vec = rand(V, *param.shape)
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print("vec.shape: ", vec.shape)
    print(".hmp(vec).shape: ", param.hmp(vec).shape)
    print(".ggnmp(vec).shape: ", param.ggnmp(vec).shape)
    print(".pchmp_clip(vec).shape: ", param.pchmp_clip(vec).shape)
    print(".pchmp_abs(vec).shape: ", param.pchmp_abs(vec).shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
vec.shape: torch.Size([1, 10, 784])
.hmp(vec).shape: torch.Size([1, 10, 784])
.ggnmp(vec).shape: torch.Size([1, 10, 784])
.pchmp_clip(vec).shape: torch.Size([1, 10, 784])
.pchmp_abs(vec).shape: torch.Size([1, 10, 784])
1.bias
.grad.shape: torch.Size([10])
vec.shape: torch.Size([1, 10])
.hmp(vec).shape: torch.Size([1, 10])
.ggnmp(vec).shape: torch.Size([1, 10])
.pchmp_clip(vec).shape: torch.Size([1, 10])
.pchmp_abs(vec).shape: torch.Size([1, 10])
```

Multiply a collection of three vectors (a matrix) with curvature blocks.

```
V = 3
for name, param in model.named_parameters():
    vec = rand(V, *param.shape)
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print("vec.shape: ", vec.shape)
    print(".hmp(vec).shape: ", param.hmp(vec).shape)
    print(".ggnmp(vec).shape: ", param.ggnmp(vec).shape)
    print(".pchmp_clip(vec).shape: ", param.pchmp_clip(vec).shape)
    print(".pchmp_abs(vec).shape: ", param.pchmp_abs(vec).shape)
```

```
1.weight
.grad.shape: torch.Size([10, 784])
vec.shape: torch.Size([3, 10, 784])
.hmp(vec).shape: torch.Size([3, 10, 784])
.ggnmp(vec).shape: torch.Size([3, 10, 784])
.pchmp_clip(vec).shape: torch.Size([3, 10, 784])
.pchmp_abs(vec).shape: torch.Size([3, 10, 784])
1.bias
.grad.shape: torch.Size([10])
vec.shape: torch.Size([3, 10])
.hmp(vec).shape: torch.Size([3, 10])
.ggnmp(vec).shape: torch.Size([3, 10])
.pchmp_clip(vec).shape: torch.Size([3, 10])
.pchmp_abs(vec).shape: torch.Size([3, 10])
```
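For this model the Hessian and GGN products should coincide, because the logits are linear in the parameters; and since the cross-entropy Hessian is positive semi-definite, the `PCHMP` variants would be expected to match as well (an expectation, not checked here). The pure-PyTorch sketch below checks the first part for the bias block and a collection of V = 3 vectors, comparing autograd Hessian-vector products with the closed-form GGN; random data stand in for MNIST, and the names `inputs` and `labels` are illustrative:

```python
import torch
from torch.nn.functional import cross_entropy, softmax

torch.manual_seed(0)
N, d_in, C, V = 512, 784, 10, 3
inputs, labels = torch.rand(N, d_in), torch.randint(0, C, (N,))
W = torch.randn(C, d_in) * 0.01
b = torch.zeros(C, requires_grad=True)

loss = cross_entropy(inputs @ W.T + b, labels)
(grad_b,) = torch.autograd.grad(loss, b, create_graph=True)

vec = torch.rand(V, C)  # V vectors shaped like the bias, stacked along a leading dimension

# Hessian-vector products by double backpropagation, one per vector
hmp = torch.stack([torch.autograd.grad(grad_b, b, v, retain_graph=True)[0] for v in vec])

# GGN-matrix product in closed form: the bias block is mean_n(diag(p_n) - p_n p_n^T)
with torch.no_grad():
    p = softmax(inputs @ W.T + b, dim=1)
    ggn_bias = (torch.diag_embed(p) - p[:, :, None] * p[:, None, :]).mean(0)
    ggnmp = vec @ ggn_bias.T  # row v equals ggn_bias @ vec[v]

print(hmp.shape, torch.allclose(hmp, ggnmp, atol=1e-6))
```

Stacking the vectors along a leading dimension, as the `V` dimension above does, lets one call multiply the curvature block with a whole matrix of vectors.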

**Total running time of the script:** ( 0 minutes 3.443 seconds)