Note
Go to the end to download the full example code.
Example using all extensions
Basic example showing how to compute the gradient and other quantities with BackPACK on a linear model for MNIST.
Let’s start by loading some dummy data and extending the model
from torch import rand
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
from backpack import backpack, extend
from backpack.extensions import (
    GGNMP,
    HMP,
    KFAC,
    KFLR,
    KFRA,
    PCHMP,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
    SumGradSquared,
    Variance,
)
from backpack.utils.examples import load_one_batch_mnist
X, y = load_one_batch_mnist(batch_size=512)
model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()
model = extend(model)
lossfunc = extend(lossfunc)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
First-order extensions
Batch gradients
loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.grad_batch.shape: torch.Size([512, 10, 784])
1.bias
.grad.shape: torch.Size([10])
.grad_batch.shape: torch.Size([512, 10])
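As a quick sanity check (not part of the original example): with the default mean reduction of CrossEntropyLoss, BackPACK scales each individual gradient by 1/N, so summing .grad_batch over the batch dimension should reproduce .grad. The sketch below assumes this scaling convention.

from torch import allclose

# Assumes mean reduction: per-sample gradients are scaled by 1/N,
# so their sum over the batch equals the full gradient.
for name, param in model.named_parameters():
    print(name, allclose(param.grad_batch.sum(0), param.grad, atol=1e-6))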
Variance
loss = lossfunc(model(X), y)

with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".variance.shape: ", param.variance.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.variance.shape: torch.Size([10, 784])
1.bias
.grad.shape: torch.Size([10])
.variance.shape: torch.Size([10])
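A hedged sketch of what this quantity means, assuming .variance is the population (1/N) variance, over the batch, of the scaled per-sample gradients stored by BatchGrad:

# Sketch: reproduce .variance from per-sample gradients (assumption:
# population variance of .grad_batch over the batch dimension).
model.zero_grad()
loss = lossfunc(model(X), y)
with backpack(BatchGrad(), Variance()):
    loss.backward()

for name, param in model.named_parameters():
    manual = param.grad_batch.var(0, unbiased=False)
    print(name, allclose(manual, param.variance, rtol=1e-4, atol=1e-9))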
Second moment/sum of gradients squared
loss = lossfunc(model(X), y)

with backpack(SumGradSquared()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.sum_grad_squared.shape: torch.Size([10, 784])
1.bias
.grad.shape: torch.Size([10])
.sum_grad_squared.shape: torch.Size([10])
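Again as a sketch (an assumption about the convention, not from the original example): .sum_grad_squared should equal the elementwise sum of the squared per-sample gradients.

# Assumes .sum_grad_squared = sum_i grad_batch[i] ** 2, elementwise.
model.zero_grad()
loss = lossfunc(model(X), y)
with backpack(BatchGrad(), SumGradSquared()):
    loss.backward()

for name, param in model.named_parameters():
    manual = (param.grad_batch ** 2).sum(0)
    print(name, allclose(manual, param.sum_grad_squared, atol=1e-9))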
L2 norm of individual gradients
loss = lossfunc(model(X), y)

with backpack(BatchL2Grad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".batch_l2.shape: ", param.batch_l2.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.batch_l2.shape: torch.Size([512])
1.bias
.grad.shape: torch.Size([10])
.batch_l2.shape: torch.Size([512])
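Note that .batch_l2 has one entry per sample. A sketch of the assumed semantics: each entry is the squared L2 norm of that sample's gradient, flattened over the parameter dimensions.

# Assumes .batch_l2[i] = || grad_batch[i] ||_2 ** 2.
model.zero_grad()
loss = lossfunc(model(X), y)
with backpack(BatchGrad(), BatchL2Grad()):
    loss.backward()

for name, param in model.named_parameters():
    manual = (param.grad_batch ** 2).flatten(start_dim=1).sum(1)
    print(name, allclose(manual, param.batch_l2, atol=1e-9))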
It’s also possible to ask for multiple quantities at once
loss = lossfunc(model(X), y)

with backpack(BatchGrad(), Variance(), SumGradSquared(), BatchL2Grad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)
    print(".variance.shape: ", param.variance.shape)
    print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
    print(".batch_l2.shape: ", param.batch_l2.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.grad_batch.shape: torch.Size([512, 10, 784])
.variance.shape: torch.Size([10, 784])
.sum_grad_squared.shape: torch.Size([10, 784])
.batch_l2.shape: torch.Size([512])
1.bias
.grad.shape: torch.Size([10])
.grad_batch.shape: torch.Size([512, 10])
.variance.shape: torch.Size([10])
.sum_grad_squared.shape: torch.Size([10])
.batch_l2.shape: torch.Size([512])
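One practical use of combining these quantities (a sketch, not part of the original example) is per-sample gradient clipping, which needs the total gradient norm of every sample across all parameters:

# Total squared per-sample gradient norm, summed over all parameters;
# the square root gives one gradient norm per sample.
total_batch_l2 = sum(param.batch_l2 for param in model.parameters())
per_sample_norm = total_batch_l2.sqrt()
print(per_sample_norm.shape)  # torch.Size([512])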
Second-order extensions
Diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation
loss = lossfunc(model(X), y)

with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
    print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.diag_ggn_mc.shape: torch.Size([10, 784])
.diag_ggn_exact.shape: torch.Size([10, 784])
1.bias
.grad.shape: torch.Size([10])
.diag_ggn_mc.shape: torch.Size([10])
.diag_ggn_exact.shape: torch.Size([10])
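Since the MC variant is a stochastic estimate, it only approximates the exact diagonal. A quick illustrative comparison (with a single MC sample the estimate is noisy):

# Relative error of the MC estimate against the exact GGN diagonal.
for name, param in model.named_parameters():
    rel_error = (
        (param.diag_ggn_mc - param.diag_ggn_exact).norm()
        / param.diag_ggn_exact.norm()
    )
    print(name, f"relative error of MC estimate: {rel_error:.3f}")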
Per-sample diagonal of the generalized Gauss-Newton and its Monte-Carlo approximation
loss = lossfunc(model(X), y)

with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".diag_ggn_mc_batch.shape: ", param.diag_ggn_mc_batch.shape)
    print(".diag_ggn_exact_batch.shape: ", param.diag_ggn_exact_batch.shape)
1.weight
.diag_ggn_mc_batch.shape: torch.Size([512, 10, 784])
.diag_ggn_exact_batch.shape: torch.Size([512, 10, 784])
1.bias
.diag_ggn_mc_batch.shape: torch.Size([512, 10])
.diag_ggn_exact_batch.shape: torch.Size([512, 10])
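A hedged consistency check (assuming the per-sample diagonals are scaled so that their sum over the batch recovers the full GGN diagonal):

# Assumption: diag_ggn_exact_batch.sum(0) == diag_ggn_exact.
model.zero_grad()
loss = lossfunc(model(X), y)
with backpack(DiagGGNExact(), BatchDiagGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    summed = param.diag_ggn_exact_batch.sum(0)
    print(name, allclose(summed, param.diag_ggn_exact, atol=1e-7))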
KFAC, KFRA and KFLR
loss = lossfunc(model(X), y)

with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac])
    print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr])
    print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra])
1.weight
.grad.shape: torch.Size([10, 784])
.kfac (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
.kflr (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
.kfra (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
1.bias
.grad.shape: torch.Size([10])
.kfac (shapes): [torch.Size([10, 10])]
.kflr (shapes): [torch.Size([10, 10])]
.kfra (shapes): [torch.Size([10, 10])]
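The Kronecker factors keep curvature storage small; when a block is small enough, the approximation can also be multiplied out explicitly. A sketch, assuming the listed factor order and that torch.kron with this ordering reproduces the block's layout (the vec convention here is an assumption, not taken from the original example):

from torch import kron

# Materialize the weight's Kronecker-factored curvature block.
# Factor order and vec convention are assumptions for illustration.
weight = dict(model.named_parameters())["1.weight"]
out_factor, in_factor = weight.kfac  # [10, 10] and [784, 784]
kfac_block = kron(out_factor, in_factor)
print(kfac_block.shape)  # torch.Size([7840, 7840])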
Diagonal Hessian and per-sample diagonal Hessian
loss = lossfunc(model(X), y)

with backpack(DiagHessian(), BatchDiagHessian()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".diag_h.shape: ", param.diag_h.shape)
    print(".diag_h_batch.shape: ", param.diag_h_batch.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.diag_h.shape: torch.Size([10, 784])
.diag_h_batch.shape: torch.Size([512, 10, 784])
1.bias
.grad.shape: torch.Size([10])
.diag_h.shape: torch.Size([10])
.diag_h_batch.shape: torch.Size([512, 10])
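For this particular model, the check below should succeed: the network is linear in its parameters, so the Hessian of the loss coincides with the generalized Gauss-Newton (a property of the model, not of BackPACK in general):

# For a linear model with a convex loss, Hessian == GGN.
model.zero_grad()
loss = lossfunc(model(X), y)
with backpack(DiagHessian(), DiagGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    print(name, allclose(param.diag_h, param.diag_ggn_exact, atol=1e-6))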
Matrix square root of the generalized Gauss-Newton or its Monte-Carlo approximation
loss = lossfunc(model(X), y)

with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape)
    print(".sqrt_ggn_mc.shape: ", param.sqrt_ggn_mc.shape)
1.weight
.grad.shape: torch.Size([10, 784])
.sqrt_ggn_exact.shape: torch.Size([10, 512, 10, 784])
.sqrt_ggn_mc.shape: torch.Size([1, 512, 10, 784])
1.bias
.grad.shape: torch.Size([10])
.sqrt_ggn_exact.shape: torch.Size([10, 512, 10])
.sqrt_ggn_mc.shape: torch.Size([1, 512, 10])
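The leading dimensions of the square root are the number of classes (or MC samples) and the batch size. Assuming the convention GGN = V V^T with V built from these slices, the GGN diagonal should be recoverable as the sum of squares over the first two dimensions; a hedged sketch:

# Assumption: diag(GGN) = sum over class and batch dims of sqrt ** 2.
model.zero_grad()
loss = lossfunc(model(X), y)
with backpack(SqrtGGNExact(), DiagGGNExact()):
    loss.backward()

for name, param in model.named_parameters():
    manual = (param.sqrt_ggn_exact ** 2).sum([0, 1])
    print(name, allclose(manual, param.diag_ggn_exact, atol=1e-6))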
Block-diagonal curvature products
Curvature-matrix product (MP) extensions provide functions that multiply with the block diagonal of different curvature matrices, such as

the Hessian (HMP)
the generalized Gauss-Newton (GGNMP)
the positive-curvature Hessian (PCHMP)
Multiply a random vector with curvature blocks.
V = 1

for name, param in model.named_parameters():
    vec = rand(V, *param.shape)
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print("vec.shape: ", vec.shape)
    print(".hmp(vec).shape: ", param.hmp(vec).shape)
    print(".ggnmp(vec).shape: ", param.ggnmp(vec).shape)
    print(".pchmp_clip(vec).shape: ", param.pchmp_clip(vec).shape)
    print(".pchmp_abs(vec).shape: ", param.pchmp_abs(vec).shape)
1.weight
.grad.shape: torch.Size([10, 784])
vec.shape: torch.Size([1, 10, 784])
.hmp(vec).shape: torch.Size([1, 10, 784])
.ggnmp(vec).shape: torch.Size([1, 10, 784])
.pchmp_clip(vec).shape: torch.Size([1, 10, 784])
.pchmp_abs(vec).shape: torch.Size([1, 10, 784])
1.bias
.grad.shape: torch.Size([10])
vec.shape: torch.Size([1, 10])
.hmp(vec).shape: torch.Size([1, 10])
.ggnmp(vec).shape: torch.Size([1, 10])
.pchmp_clip(vec).shape: torch.Size([1, 10])
.pchmp_abs(vec).shape: torch.Size([1, 10])
Multiply a collection of three vectors (a matrix) with curvature blocks.
V = 3

for name, param in model.named_parameters():
    vec = rand(V, *param.shape)
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print("vec.shape: ", vec.shape)
    print(".hmp(vec).shape: ", param.hmp(vec).shape)
    print(".ggnmp(vec).shape: ", param.ggnmp(vec).shape)
    print(".pchmp_clip(vec).shape: ", param.pchmp_clip(vec).shape)
    print(".pchmp_abs(vec).shape: ", param.pchmp_abs(vec).shape)
1.weight
.grad.shape: torch.Size([10, 784])
vec.shape: torch.Size([3, 10, 784])
.hmp(vec).shape: torch.Size([3, 10, 784])
.ggnmp(vec).shape: torch.Size([3, 10, 784])
.pchmp_clip(vec).shape: torch.Size([3, 10, 784])
.pchmp_abs(vec).shape: torch.Size([3, 10, 784])
1.bias
.grad.shape: torch.Size([10])
vec.shape: torch.Size([3, 10])
.hmp(vec).shape: torch.Size([3, 10])
.ggnmp(vec).shape: torch.Size([3, 10])
.pchmp_clip(vec).shape: torch.Size([3, 10])
.pchmp_abs(vec).shape: torch.Size([3, 10])
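These matrix-free multiplications enable iterative algorithms that never form the curvature blocks explicitly. As an illustrative sketch (not part of the original example), a few steps of power iteration with .hmp() estimate the largest eigenvalue of the weight's Hessian block; the iteration count is an arbitrary choice and convergence checks are omitted:

# Power iteration on the weight's Hessian block via .hmp().
weight = dict(model.named_parameters())["1.weight"]
v = rand(1, *weight.shape)

for _ in range(20):
    v = weight.hmp(v)
    v = v / v.norm()

# Rayleigh quotient of the normalized iterate.
eig_estimate = (v * weight.hmp(v)).sum()
print("estimated top Hessian eigenvalue:", eig_estimate.item())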
Total running time of the script: (0 minutes 6.774 seconds)