Note
Click here to download the full example code
First order extensions with a ResNet
Let’s get the imports, configuration and some helper functions out of the way first.
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist
import torch.nn.functional as F
# Number of samples requested from load_one_batch_mnist below.
BATCH_SIZE = 3
# Seed PyTorch's random number generator.
torch.manual_seed(0)
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def get_accuracy(output, targets):
    """Return the fraction of predictions that match the targets.

    The predicted class of each sample is the index of the largest entry of
    ``output`` along dimension 1.

    Args:
        output: Tensor of per-class scores with the classes along dimension 1.
        targets: Tensor of class labels; the predictions are reshaped to its
            shape before the element-wise comparison.

    Returns:
        float: Mean of the element-wise matches (1.0 means all correct).
    """
    predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
    return predictions.eq(targets).float().mean().item()
# Load one batch of MNIST data (x is later fed to the model, y is used as the
# cross-entropy target) and move both tensors to DEVICE.
x, y = load_one_batch_mnist(batch_size=BATCH_SIZE)
x, y = x.to(DEVICE), y.to(DEVICE)
Out:
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
0.0%
0.1%
0.2%
0.2%
0.3%
0.4%
0.5%
0.6%
0.7%
0.7%
0.8%
0.9%
1.0%
1.1%
1.2%
1.2%
1.3%
1.4%
1.5%
1.6%
1.7%
1.7%
1.8%
1.9%
2.0%
2.1%
2.1%
2.2%
2.3%
2.4%
2.5%
2.6%
2.6%
2.7%
2.8%
2.9%
3.0%
3.1%
3.1%
3.2%
3.3%
3.4%
3.5%
3.6%
3.6%
3.7%
3.8%
3.9%
4.0%
4.0%
4.1%
4.2%
4.3%
4.4%
4.5%
4.5%
4.6%
4.7%
4.8%
4.9%
5.0%
5.0%
5.1%
5.2%
5.3%
5.4%
5.5%
5.5%
5.6%
5.7%
5.8%
5.9%
6.0%
6.0%
6.1%
6.2%
6.3%
6.4%
6.4%
6.5%
6.6%
6.7%
6.8%
6.9%
6.9%
7.0%
7.1%
7.2%
7.3%
7.4%
7.4%
7.5%
7.6%
7.7%
7.8%
7.9%
7.9%
8.0%
8.1%
8.2%
8.3%
8.3%
8.4%
8.5%
8.6%
8.7%
8.8%
8.8%
8.9%
9.0%
9.1%
9.2%
9.3%
9.3%
9.4%
9.5%
9.6%
9.7%
9.8%
9.8%
9.9%
10.0%
10.1%
10.2%
10.2%
10.3%
10.4%
10.5%
10.6%
10.7%
10.7%
10.8%
10.9%
11.0%
11.1%
11.2%
11.2%
11.3%
11.4%
11.5%
11.6%
11.7%
11.7%
11.8%
11.9%
12.0%
12.1%
12.1%
12.2%
12.3%
12.4%
12.5%
12.6%
12.6%
12.7%
12.8%
12.9%
13.0%
13.1%
13.1%
13.2%
13.3%
13.4%
13.5%
13.6%
13.6%
13.7%
13.8%
13.9%
14.0%
14.0%
14.1%
14.2%
14.3%
14.4%
14.5%
14.5%
14.6%
14.7%
14.8%
14.9%
15.0%
15.0%
15.1%
15.2%
15.3%
15.4%
15.5%
15.5%
15.6%
15.7%
15.8%
15.9%
16.0%
16.0%
16.1%
16.2%
16.3%
16.4%
16.4%
16.5%
16.6%
16.7%
16.8%
16.9%
16.9%
17.0%
17.1%
17.2%
17.3%
17.4%
17.4%
17.5%
17.6%
17.7%
17.8%
17.9%
17.9%
18.0%
18.1%
18.2%
18.3%
18.3%
18.4%
18.5%
18.6%
18.7%
18.8%
18.8%
18.9%
19.0%
19.1%
19.2%
19.3%
19.3%
19.4%
19.5%
19.6%
19.7%
19.8%
19.8%
19.9%
20.0%
20.1%
20.2%
20.2%
20.3%
20.4%
20.5%
20.6%
20.7%
20.7%
20.8%
20.9%
21.0%
21.1%
21.2%
21.2%
21.3%
21.4%
21.5%
21.6%
21.7%
21.7%
21.8%
21.9%
22.0%
22.1%
22.1%
22.2%
22.3%
22.4%
22.5%
22.6%
22.6%
22.7%
22.8%
22.9%
23.0%
23.1%
23.1%
23.2%
23.3%
23.4%
23.5%
23.6%
23.6%
23.7%
23.8%
23.9%
24.0%
24.0%
24.1%
24.2%
24.3%
24.4%
24.5%
24.5%
24.6%
24.7%
24.8%
24.9%
25.0%
25.0%
25.1%
25.2%
25.3%
25.4%
25.5%
25.5%
25.6%
25.7%
25.8%
25.9%
26.0%
26.0%
26.1%
26.2%
26.3%
26.4%
26.4%
26.5%
26.6%
26.7%
26.8%
26.9%
26.9%
27.0%
27.1%
27.2%
27.3%
27.4%
27.4%
27.5%
27.6%
27.7%
27.8%
27.9%
27.9%
28.0%
28.1%
28.2%
28.3%
28.3%
28.4%
28.5%
28.6%
28.7%
28.8%
28.8%
28.9%
29.0%
29.1%
29.2%
29.3%
29.3%
29.4%
29.5%
29.6%
29.7%
29.8%
29.8%
29.9%
30.0%
30.1%
30.2%
30.2%
30.3%
30.4%
30.5%
30.6%
30.7%
30.7%
30.8%
30.9%
31.0%
31.1%
31.2%
31.2%
31.3%
31.4%
31.5%
31.6%
31.7%
31.7%
31.8%
31.9%
32.0%
32.1%
32.1%
32.2%
32.3%
32.4%
32.5%
32.6%
32.6%
32.7%
32.8%
32.9%
33.0%
33.1%
33.1%
33.2%
33.3%
33.4%
33.5%
33.6%
33.6%
33.7%
33.8%
33.9%
34.0%
34.0%
34.1%
34.2%
34.3%
34.4%
34.5%
34.5%
34.6%
34.7%
34.8%
34.9%
35.0%
35.0%
35.1%
35.2%
35.3%
35.4%
35.5%
35.5%
35.6%
35.7%
35.8%
35.9%
36.0%
36.0%
36.1%
36.2%
36.3%
36.4%
36.4%
36.5%
36.6%
36.7%
36.8%
36.9%
36.9%
37.0%
37.1%
37.2%
37.3%
37.4%
37.4%
37.5%
37.6%
37.7%
37.8%
37.9%
37.9%
38.0%
38.1%
38.2%
38.3%
38.3%
38.4%
38.5%
38.6%
38.7%
38.8%
38.8%
38.9%
39.0%
39.1%
39.2%
39.3%
39.3%
39.4%
39.5%
39.6%
39.7%
39.8%
39.8%
39.9%
40.0%
40.1%
40.2%
40.2%
40.3%
40.4%
40.5%
40.6%
40.7%
40.7%
40.8%
40.9%
41.0%
41.1%
41.2%
41.2%
41.3%
41.4%
41.5%
41.6%
41.7%
41.7%
41.8%
41.9%
42.0%
42.1%
42.1%
42.2%
42.3%
42.4%
42.5%
42.6%
42.6%
42.7%
42.8%
42.9%
43.0%
43.1%
43.1%
43.2%
43.3%
43.4%
43.5%
43.6%
43.6%
43.7%
43.8%
43.9%
44.0%
44.0%
44.1%
44.2%
44.3%
44.4%
44.5%
44.5%
44.6%
44.7%
44.8%
44.9%
45.0%
45.0%
45.1%
45.2%
45.3%
45.4%
45.5%
45.5%
45.6%
45.7%
45.8%
45.9%
45.9%
46.0%
46.1%
46.2%
46.3%
46.4%
46.4%
46.5%
46.6%
46.7%
46.8%
46.9%
46.9%
47.0%
47.1%
47.2%
47.3%
47.4%
47.4%
47.5%
47.6%
47.7%
47.8%
47.9%
47.9%
48.0%
48.1%
48.2%
48.3%
48.3%
48.4%
48.5%
48.6%
48.7%
48.8%
48.8%
48.9%
49.0%
49.1%
49.2%
49.3%
49.3%
49.4%
49.5%
49.6%
49.7%
49.8%
49.8%
49.9%
50.0%
50.1%
50.2%
50.2%
50.3%
50.4%
50.5%
50.6%
50.7%
50.7%
50.8%
50.9%
51.0%
51.1%
51.2%
51.2%
51.3%
51.4%
51.5%
51.6%
51.7%
51.7%
51.8%
51.9%
52.0%
52.1%
52.1%
52.2%
52.3%
52.4%
52.5%
52.6%
52.6%
52.7%
52.8%
52.9%
53.0%
53.1%
53.1%
53.2%
53.3%
53.4%
53.5%
53.6%
53.6%
53.7%
53.8%
53.9%
54.0%
54.0%
54.1%
54.2%
54.3%
54.4%
54.5%
54.5%
54.6%
54.7%
54.8%
54.9%
55.0%
55.0%
55.1%
55.2%
55.3%
55.4%
55.5%
55.5%
55.6%
55.7%
55.8%
55.9%
55.9%
56.0%
56.1%
56.2%
56.3%
56.4%
56.4%
56.5%
56.6%
56.7%
56.8%
56.9%
56.9%
57.0%
57.1%
57.2%
57.3%
57.4%
57.4%
57.5%
57.6%
57.7%
57.8%
57.9%
57.9%
58.0%
58.1%
58.2%
58.3%
58.3%
58.4%
58.5%
58.6%
58.7%
58.8%
58.8%
58.9%
59.0%
59.1%
59.2%
59.3%
59.3%
59.4%
59.5%
59.6%
59.7%
59.8%
59.8%
59.9%
60.0%
60.1%
60.2%
60.2%
60.3%
60.4%
60.5%
60.6%
60.7%
60.7%
60.8%
60.9%
61.0%
61.1%
61.2%
61.2%
61.3%
61.4%
61.5%
61.6%
61.7%
61.7%
61.8%
61.9%
62.0%
62.1%
62.1%
62.2%
62.3%
62.4%
62.5%
62.6%
62.6%
62.7%
62.8%
62.9%
63.0%
63.1%
63.1%
63.2%
63.3%
63.4%
63.5%
63.6%
63.6%
63.7%
63.8%
63.9%
64.0%
64.0%
64.1%
64.2%
64.3%
64.4%
64.5%
64.5%
64.6%
64.7%
64.8%
64.9%
65.0%
65.0%
65.1%
65.2%
65.3%
65.4%
65.5%
65.5%
65.6%
65.7%
65.8%
65.9%
65.9%
66.0%
66.1%
66.2%
66.3%
66.4%
66.4%
66.5%
66.6%
66.7%
66.8%
66.9%
66.9%
67.0%
67.1%
67.2%
67.3%
67.4%
67.4%
67.5%
67.6%
67.7%
67.8%
67.9%
67.9%
68.0%
68.1%
68.2%
68.3%
68.3%
68.4%
68.5%
68.6%
68.7%
68.8%
68.8%
68.9%
69.0%
69.1%
69.2%
69.3%
69.3%
69.4%
69.5%
69.6%
69.7%
69.8%
69.8%
69.9%
70.0%
70.1%
70.2%
70.2%
70.3%
70.4%
70.5%
70.6%
70.7%
70.7%
70.8%
70.9%
71.0%
71.1%
71.2%
71.2%
71.3%
71.4%
71.5%
71.6%
71.7%
71.7%
71.8%
71.9%
72.0%
72.1%
72.1%
72.2%
72.3%
72.4%
72.5%
72.6%
72.6%
72.7%
72.8%
72.9%
73.0%
73.1%
73.1%
73.2%
73.3%
73.4%
73.5%
73.6%
73.6%
73.7%
73.8%
73.9%
74.0%
74.0%
74.1%
74.2%
74.3%
74.4%
74.5%
74.5%
74.6%
74.7%
74.8%
74.9%
75.0%
75.0%
75.1%
75.2%
75.3%
75.4%
75.5%
75.5%
75.6%
75.7%
75.8%
75.9%
75.9%
76.0%
76.1%
76.2%
76.3%
76.4%
76.4%
76.5%
76.6%
76.7%
76.8%
76.9%
76.9%
77.0%
77.1%
77.2%
77.3%
77.4%
77.4%
77.5%
77.6%
77.7%
77.8%
77.9%
77.9%
78.0%
78.1%
78.2%
78.3%
78.3%
78.4%
78.5%
78.6%
78.7%
78.8%
78.8%
78.9%
79.0%
79.1%
79.2%
79.3%
79.3%
79.4%
79.5%
79.6%
79.7%
79.8%
79.8%
79.9%
80.0%
80.1%
80.2%
80.2%
80.3%
80.4%
80.5%
80.6%
80.7%
80.7%
80.8%
80.9%
81.0%
81.1%
81.2%
81.2%
81.3%
81.4%
81.5%
81.6%
81.7%
81.7%
81.8%
81.9%
82.0%
82.1%
82.1%
82.2%
82.3%
82.4%
82.5%
82.6%
82.6%
82.7%
82.8%
82.9%
83.0%
83.1%
83.1%
83.2%
83.3%
83.4%
83.5%
83.6%
83.6%
83.7%
83.8%
83.9%
84.0%
84.0%
84.1%
84.2%
84.3%
84.4%
84.5%
84.5%
84.6%
84.7%
84.8%
84.9%
85.0%
85.0%
85.1%
85.2%
85.3%
85.4%
85.5%
85.5%
85.6%
85.7%
85.8%
85.9%
85.9%
86.0%
86.1%
86.2%
86.3%
86.4%
86.4%
86.5%
86.6%
86.7%
86.8%
86.9%
86.9%
87.0%
87.1%
87.2%
87.3%
87.4%
87.4%
87.5%
87.6%
87.7%
87.8%
87.9%
87.9%
88.0%
88.1%
88.2%
88.3%
88.3%
88.4%
88.5%
88.6%
88.7%
88.8%
88.8%
88.9%
89.0%
89.1%
89.2%
89.3%
89.3%
89.4%
89.5%
89.6%
89.7%
89.8%
89.8%
89.9%
90.0%
90.1%
90.2%
90.2%
90.3%
90.4%
90.5%
90.6%
90.7%
90.7%
90.8%
90.9%
91.0%
91.1%
91.2%
91.2%
91.3%
91.4%
91.5%
91.6%
91.7%
91.7%
91.8%
91.9%
92.0%
92.1%
92.1%
92.2%
92.3%
92.4%
92.5%
92.6%
92.6%
92.7%
92.8%
92.9%
93.0%
93.1%
93.1%
93.2%
93.3%
93.4%
93.5%
93.6%
93.6%
93.7%
93.8%
93.9%
94.0%
94.0%
94.1%
94.2%
94.3%
94.4%
94.5%
94.5%
94.6%
94.7%
94.8%
94.9%
95.0%
95.0%
95.1%
95.2%
95.3%
95.4%
95.5%
95.5%
95.6%
95.7%
95.8%
95.9%
95.9%
96.0%
96.1%
96.2%
96.3%
96.4%
96.4%
96.5%
96.6%
96.7%
96.8%
96.9%
96.9%
97.0%
97.1%
97.2%
97.3%
97.4%
97.4%
97.5%
97.6%
97.7%
97.8%
97.9%
97.9%
98.0%
98.1%
98.2%
98.3%
98.3%
98.4%
98.5%
98.6%
98.7%
98.8%
98.8%
98.9%
99.0%
99.1%
99.2%
99.3%
99.3%
99.4%
99.5%
99.6%
99.7%
99.8%
99.8%
99.9%
100.0%
100.1%Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
0.0%
28.4%
56.7%
85.1%
113.5%Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
0.0%
0.5%
1.0%
1.5%
2.0%
2.5%
3.0%
3.5%
4.0%
4.5%
5.0%
5.5%
6.0%
6.5%
7.0%
7.5%
7.9%
8.4%
8.9%
9.4%
9.9%
10.4%
10.9%
11.4%
11.9%
12.4%
12.9%
13.4%
13.9%
14.4%
14.9%
15.4%
15.9%
16.4%
16.9%
17.4%
17.9%
18.4%
18.9%
19.4%
19.9%
20.4%
20.9%
21.4%
21.9%
22.4%
22.9%
23.4%
23.8%
24.3%
24.8%
25.3%
25.8%
26.3%
26.8%
27.3%
27.8%
28.3%
28.8%
29.3%
29.8%
30.3%
30.8%
31.3%
31.8%
32.3%
32.8%
33.3%
33.8%
34.3%
34.8%
35.3%
35.8%
36.3%
36.8%
37.3%
37.8%
38.3%
38.8%
39.2%
39.7%
40.2%
40.7%
41.2%
41.7%
42.2%
42.7%
43.2%
43.7%
44.2%
44.7%
45.2%
45.7%
46.2%
46.7%
47.2%
47.7%
48.2%
48.7%
49.2%
49.7%
50.2%
50.7%
51.2%
51.7%
52.2%
52.7%
53.2%
53.7%
54.2%
54.7%
55.1%
55.6%
56.1%
56.6%
57.1%
57.6%
58.1%
58.6%
59.1%
59.6%
60.1%
60.6%
61.1%
61.6%
62.1%
62.6%
63.1%
63.6%
64.1%
64.6%
65.1%
65.6%
66.1%
66.6%
67.1%
67.6%
68.1%
68.6%
69.1%
69.6%
70.1%
70.5%
71.0%
71.5%
72.0%
72.5%
73.0%
73.5%
74.0%
74.5%
75.0%
75.5%
76.0%
76.5%
77.0%
77.5%
78.0%
78.5%
79.0%
79.5%
80.0%
80.5%
81.0%
81.5%
82.0%
82.5%
83.0%
83.5%
84.0%
84.5%
85.0%
85.5%
86.0%
86.4%
86.9%
87.4%
87.9%
88.4%
88.9%
89.4%
89.9%
90.4%
90.9%
91.4%
91.9%
92.4%
92.9%
93.4%
93.9%
94.4%
94.9%
95.4%
95.9%
96.4%
96.9%
97.4%
97.9%
98.4%
98.9%
99.4%
99.9%
100.4%Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
0.0%
180.4%Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!
We can build a ResNet by extending torch.nn.Module.
As long as the layers with parameters (torch.nn.Conv2d and torch.nn.Linear) are nn modules, BackPACK can extend them, and this is all that is needed for first order extensions.
We can rewrite the forward to implement the residual connection, and extend() the resulting model.
class MyFirstResNet(torch.nn.Module):
    """Small residual network: two 3x3 convolutions plus a shortcut, then a linear layer.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved and the linear layer receives
    ``input_dim[0] * input_dim[1] * C_hid`` features per sample.
    """

    def __init__(self, C_in=1, C_hid=5, input_dim=(28, 28), output_dim=10):
        """Create the layers.

        Args:
            C_in: Number of input channels.
            C_hid: Number of channels produced by both convolutions.
            input_dim: (height, width) of the input images.
            output_dim: Number of output features of the final linear layer.
        """
        super().__init__()

        self.conv1 = torch.nn.Conv2d(C_in, C_hid, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(C_hid, C_hid, kernel_size=3, stride=1, padding=1)
        self.linear1 = torch.nn.Linear(input_dim[0] * input_dim[1] * C_hid, output_dim)

        # The shortcut must produce C_hid channels so it can be added to the
        # convolution branch; a 1x1 convolution adapts the channel count when
        # the input does not already have C_hid channels.
        if C_in == C_hid:
            self.shortcut = torch.nn.Identity()
        else:
            self.shortcut = torch.nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1)

    def forward(self, x):
        """Apply the residual block and the linear layer.

        Args:
            x: Input tensor with ``C_in`` channels along dimension 1.

        Returns:
            Tensor with one row per sample and ``output_dim`` columns.
        """
        residual = self.shortcut(x)
        x = self.conv2(F.relu(self.conv1(x)))
        # Out-of-place addition: the output of a module must not be modified
        # in place when backward hooks are registered on it (PyTorch raises an
        # error in that case), and extended modules may carry such hooks.
        x = x + residual
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        return x
# Instantiate the network, wrap it with BackPACK's extend() and move it to DEVICE.
model = extend(MyFirstResNet()).to(DEVICE)
Using BatchGrad in a with backpack(...) block, we can access the individual gradients for each sample.
The loss does not need to be extended in this case either, as it does not have model parameters and BackPACK does not need to know about it for first order extensions. This also means you can use any custom loss function.
# Clear existing gradients, then backpropagate the summed loss inside the
# BackPACK context; the backward call must run within the `with` block for
# the BatchGrad extension to be active.
model.zero_grad()
loss = F.cross_entropy(model(x), y, reduction="sum")
with backpack(BatchGrad()):
    loss.backward()

# Print, for every parameter, the shape of the usual gradient next to the
# shape of the per-sample gradients read from `grad_batch`.
print("{:<20} {:<30} {:<30}".format("Param", "grad", "grad (batch)"))
print("-" * 80)
for name, p in model.named_parameters():
    print(
        "{:<20}: {:<30} {:<30}".format(name, str(p.grad.shape), str(p.grad_batch.shape))
    )
Out:
Param grad grad (batch)
--------------------------------------------------------------------------------
conv1.weight : torch.Size([5, 1, 3, 3]) torch.Size([3, 5, 1, 3, 3])
conv1.bias : torch.Size([5]) torch.Size([3, 5])
conv2.weight : torch.Size([5, 5, 3, 3]) torch.Size([3, 5, 5, 3, 3])
conv2.bias : torch.Size([5]) torch.Size([3, 5])
linear1.weight : torch.Size([10, 3920]) torch.Size([3, 10, 3920])
linear1.bias : torch.Size([10]) torch.Size([3, 10])
shortcut.weight : torch.Size([5, 1, 1, 1]) torch.Size([3, 5, 1, 1, 1])
shortcut.bias : torch.Size([5]) torch.Size([3, 5])
To check that everything works, let’s compute one individual gradient with PyTorch (using a single sample in a forward and backward pass) and compare it with the one computed by BackPACK.
# Select a single sample and keep a leading batch dimension of size 1.
sample_to_check = 1
x_to_check = x[sample_to_check, :].unsqueeze(0)
y_to_check = y[sample_to_check].unsqueeze(0)

# Plain PyTorch forward/backward pass on that one sample.
model.zero_grad()
loss = F.cross_entropy(model(x_to_check), y_to_check)
loss.backward()

# Compare the previously stored per-sample gradient with the gradient that
# was just computed for the single sample.
print("Do the individual gradient match?")
for name, p in model.named_parameters():
    match = torch.allclose(p.grad_batch[sample_to_check, :], p.grad, atol=1e-7)
    print("{:<20} {}".format(name, match))
Out:
Do the individual gradient match?
conv1.weight True
conv1.bias True
conv2.weight True
conv2.bias True
linear1.weight True
linear1.bias True
shortcut.weight True
shortcut.bias True
Total running time of the script: ( 0 minutes 3.261 seconds)