Mini-batch sub-sampling

By default, BackPACK’s extensions consider all samples in a mini-batch. Some extensions support limiting the computation to a subset of samples; this example shows how to restrict it to such a subset.

This is useful for applications where different parts of the mini-batch serve different purposes, e.g. computing curvature information on one subset and gradient information on another. Restricting the computation to fewer samples also reduces run time and memory.

Note

Not all extensions support sub-sampling yet. Please create a feature request in the repository if the extension you need is not supported.

Let’s start by loading some dummy data and extending the model:

from torch import allclose, cuda, device, manual_seed
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist

# make deterministic
manual_seed(0)

dev = device("cuda" if cuda.is_available() else "cpu")

# data
X, y = load_one_batch_mnist(batch_size=128)
X, y = X.to(dev), y.to(dev)

# model
model = Sequential(Flatten(), Linear(784, 10)).to(dev)
lossfunc = CrossEntropyLoss().to(dev)

model = extend(model)
lossfunc = extend(lossfunc)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Individual gradients for a mini-batch subset

Let’s say we only want to compute individual gradients for samples 0, 1, 13, and 42. Naively, we could perform the computation for all samples, then slice out the samples we care about.

# selected samples
subsampling = [0, 1, 13, 42]

loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

# naive approach: compute for all, slice out relevant
naive = [p.grad_batch[subsampling] for p in model.parameters()]

This is inefficient: individual gradients are computed for all samples, and most of them are discarded afterwards. We can do better by specifying the active samples directly through the subsampling argument of BatchGrad.

loss = lossfunc(model(X), y)

# efficient approach: specify active samples in backward pass
with backpack(BatchGrad(subsampling=subsampling)):
    loss.backward()

efficient = [p.grad_batch for p in model.parameters()]
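
The grad_batch attributes now only contain the selected samples. As a quick sanity check (a sketch using the objects defined above), the leading dimension of each parameter’s grad_batch should equal the number of sub-sampled indices rather than the mini-batch size:

# leading dimension is the number of selected samples, not the batch size (128)
for p in model.parameters():
    assert p.grad_batch.shape[0] == len(subsampling)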

Let’s verify that both ways yield the same result:

match = all(
    allclose(g_naive, g_efficient) for g_naive, g_efficient in zip(naive, efficient)
)

print(f"Naive and efficient sub-sampled individual gradients match? {match}")

if not match:
    raise ValueError("Naive and efficient sub-sampled individual gradients don't match.")
Naive and efficient sub-sampled individual gradients match? True
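
As a small usage sketch, the sub-sampled individual gradients map one-to-one to the chosen indices and can be post-processed like any other tensor, for example to compute the gradient norm of each selected sample:

# gradient norm of each selected sample, aggregated over all parameters
# (note: with mean reduction, each individual gradient carries the loss's 1/batch_size factor)
squared_norms = sum(
    (p.grad_batch ** 2).flatten(start_dim=1).sum(dim=1) for p in model.parameters()
)

for idx, norm in zip(subsampling, squared_norms.sqrt()):
    print(f"Sample {idx}: gradient norm {norm.item():.3e}")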

Total running time of the script: ( 0 minutes 1.025 seconds)
