.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "basic_usage/example_all_in_one.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_basic_usage_example_all_in_one.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_basic_usage_example_all_in_one.py:


Example using all extensions
==============================

Basic example showing how to compute the gradient and other quantities with
BackPACK on a linear model for MNIST.

.. GENERATED FROM PYTHON SOURCE LINES 11-12

Let's start by loading one batch of MNIST data and extending the model and the
loss function with BackPACK.

.. GENERATED FROM PYTHON SOURCE LINES 12-47

.. code-block:: Python

    from torch import rand
    from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

    from backpack import backpack, extend
    from backpack.extensions import (
        GGNMP,
        HMP,
        KFAC,
        KFLR,
        KFRA,
        PCHMP,
        BatchDiagGGNExact,
        BatchDiagGGNMC,
        BatchDiagHessian,
        BatchGrad,
        BatchL2Grad,
        DiagGGNExact,
        DiagGGNMC,
        DiagHessian,
        SqrtGGNExact,
        SqrtGGNMC,
        SumGradSquared,
        Variance,
    )
    from backpack.utils.examples import load_one_batch_mnist

    X, y = load_one_batch_mnist(batch_size=512)

    model = Sequential(Flatten(), Linear(784, 10))
    lossfunc = CrossEntropyLoss()

    # extend() registers the hooks BackPACK needs during the backward pass
    model = extend(model)
    lossfunc = extend(lossfunc)
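
If downloading MNIST is not possible, the example can still be run on random
data. The following is a minimal sketch of a hypothetical offline stand-in for
``load_one_batch_mnist``; the helper name ``random_mnist_like_batch`` is our
own, and it assumes that one MNIST batch consists of images of shape
``[N, 1, 28, 28]`` and integer labels in ``0, ..., 9``. It produces random
values, not digits, so only shapes (not numbers) are comparable to the real
data.

.. code-block:: python

    import torch


    def random_mnist_like_batch(batch_size=512, seed=0):
        """Hypothetical offline stand-in for ``load_one_batch_mnist``.

        Returns random "images" of shape [batch_size, 1, 28, 28] with values
        in [0, 1) and random integer labels in {0, ..., 9}. NOT MNIST data.
        """
        generator = torch.Generator().manual_seed(seed)
        images = torch.rand(batch_size, 1, 28, 28, generator=generator)
        labels = torch.randint(0, 10, (batch_size,), generator=generator)
        return images, labels


    X, y = random_mnist_like_batch(batch_size=512)
    print(X.shape, y.shape)

Using this helper in place of ``load_one_batch_mnist(batch_size=512)`` should
leave every shape printed in this example unchanged; the model and the loss
function still need to be wrapped with ``extend`` as above.
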

.. GENERATED FROM PYTHON SOURCE LINES 48-50

First-order extensions
----------------------

.. GENERATED FROM PYTHON SOURCE LINES 52-53

Batch gradients, i.e. the individual gradient of each sample in the mini-batch
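
As a cross-check of what ``grad_batch`` contains, the following sketch computes
individual gradients with plain PyTorch, one backward pass per sample, without
BackPACK. It assumes the scaling convention described in BackPACK's
documentation for losses with ``reduction="mean"`` (the default of
``CrossEntropyLoss``): each individual gradient includes the factor 1/N, so the
individual gradients sum to ``.grad``. The small random batch (``N = 8``) and
the variable names are illustrative only and do not overwrite the example's
``X``, ``y`` or ``model``.

.. code-block:: python

    import torch
    from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

    torch.manual_seed(0)
    N = 8  # small illustrative batch size
    inputs = torch.rand(N, 1, 28, 28)
    labels = torch.randint(0, 10, (N,))
    net = Sequential(Flatten(), Linear(784, 10))
    loss_fn = CrossEntropyLoss()  # reduction="mean" by default

    # Full-batch gradient, as PyTorch stores it in .grad
    net.zero_grad()
    loss_fn(net(inputs), labels).backward()
    full_grads = [p.grad.clone() for p in net.parameters()]

    # Individual gradients: sample n contributes loss_n / N to the mean loss,
    # so its gradient carries the same 1/N factor
    per_sample = [[] for _ in net.parameters()]
    for n in range(N):
        net.zero_grad()
        loss_n = loss_fn(net(inputs[n : n + 1]), labels[n : n + 1]) / N
        loss_n.backward()
        for store, p in zip(per_sample, net.parameters()):
            store.append(p.grad.clone())
    grad_batch = [torch.stack(store) for store in per_sample]  # [N, *param.shape]

    for (name, _), g_batch, g in zip(net.named_parameters(), grad_batch, full_grads):
        # summing the individual gradients over the batch recovers .grad
        print(name, tuple(g_batch.shape), torch.allclose(g_batch.sum(dim=0), g, atol=1e-6))

With BackPACK, the same quantities should come out of a single backward pass in
``param.grad_batch``; the loop only makes the definition explicit and scales
poorly with the batch size.
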

.. GENERATED FROM PYTHON SOURCE LINES 53-63

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(BatchGrad()):
        loss.backward()

    # individual gradients are stored in param.grad_batch
    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".grad_batch.shape: ", param.grad_batch.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .grad_batch.shape: torch.Size([512, 10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    .grad_batch.shape: torch.Size([512, 10])

.. GENERATED FROM PYTHON SOURCE LINES 64-65

Variance of the individual gradients across the mini-batch

.. GENERATED FROM PYTHON SOURCE LINES 65-75

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(Variance()):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".variance.shape: ", param.variance.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .variance.shape: torch.Size([10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    .variance.shape: torch.Size([10])

.. GENERATED FROM PYTHON SOURCE LINES 76-77

Second moment, i.e. the sum of squared individual gradients

.. GENERATED FROM PYTHON SOURCE LINES 77-87

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(SumGradSquared()):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .sum_grad_squared.shape: torch.Size([10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    .sum_grad_squared.shape: torch.Size([10])

.. GENERATED FROM PYTHON SOURCE LINES 88-89

Squared L2 norm of each individual gradient

.. GENERATED FROM PYTHON SOURCE LINES 89-99
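
For a linear layer, these first-order quantities have a simple closed form. The
following is an independent plain-PyTorch sketch on small random data, not
BackPACK's implementation: it forms the individual weight gradients as outer
products of the output gradient ``delta_n`` and the input ``x_n``, then derives
the squared L2 norm per sample (our reading of what ``batch_l2`` stores) and the
sum of squared individual gradients (``sum_grad_squared``). It assumes a
mean-reduced cross-entropy loss, as in this example; all variable names are
illustrative.

.. code-block:: python

    import torch
    from torch.nn.functional import cross_entropy, one_hot, softmax

    torch.manual_seed(0)
    N, D, C = 8, 784, 10  # batch size, input features, classes
    inputs = torch.rand(N, D)
    labels = torch.randint(0, C, (N,))
    weight = 0.01 * torch.randn(C, D)
    bias = torch.zeros(C)

    # Gradient of the mean cross-entropy w.r.t. each sample's logits, [N, C]
    logits = inputs @ weight.T + bias
    delta = (softmax(logits, dim=1) - one_hot(labels, C).float()) / N

    # Individual gradients of the linear layer: outer products delta_n x_n^T
    grad_batch_w = torch.einsum("nc,nd->ncd", delta, inputs)  # [N, C, D]
    grad_batch_b = delta  # [N, C]

    # Squared L2 norm of each individual gradient, one value per sample
    batch_l2_w = (grad_batch_w**2).sum(dim=(1, 2))  # [N]
    batch_l2_b = (grad_batch_b**2).sum(dim=1)  # [N]
    # Same values without materializing [N, C, D]:
    # ||delta_n x_n^T||_F^2 = ||delta_n||^2 * ||x_n||^2
    batch_l2_w_fast = (delta**2).sum(dim=1) * (inputs**2).sum(dim=1)

    # Sum of squared individual gradients, same shape as the parameter
    sum_grad_squared_w = (grad_batch_w**2).sum(dim=0)  # [C, D]

    # Cross-check with autograd: individual gradients sum to the full gradient
    weight_ref = weight.clone().requires_grad_(True)
    cross_entropy(inputs @ weight_ref.T + bias, labels).backward()
    print(torch.allclose(grad_batch_w.sum(dim=0), weight_ref.grad, atol=1e-6))
    print(torch.allclose(batch_l2_w, batch_l2_w_fast, rtol=1e-4))

The outer-product identity is why per-sample squared norms of a linear layer
can be obtained without storing the full ``[N, C, D]`` tensor of individual
gradients.
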

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(BatchL2Grad()):
        loss.backward()

    # one value per sample, independent of the parameter's shape
    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".batch_l2.shape: ", param.batch_l2.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .batch_l2.shape: torch.Size([512])
    1.bias
    .grad.shape: torch.Size([10])
    .batch_l2.shape: torch.Size([512])

.. GENERATED FROM PYTHON SOURCE LINES 100-101

Multiple quantities can be computed in the same backward pass by passing
several extensions to ``backpack``.

.. GENERATED FROM PYTHON SOURCE LINES 101-114

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(BatchGrad(), Variance(), SumGradSquared(), BatchL2Grad()):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".grad_batch.shape: ", param.grad_batch.shape)
        print(".variance.shape: ", param.variance.shape)
        print(".sum_grad_squared.shape: ", param.sum_grad_squared.shape)
        print(".batch_l2.shape: ", param.batch_l2.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .grad_batch.shape: torch.Size([512, 10, 784])
    .variance.shape: torch.Size([10, 784])
    .sum_grad_squared.shape: torch.Size([10, 784])
    .batch_l2.shape: torch.Size([512])
    1.bias
    .grad.shape: torch.Size([10])
    .grad_batch.shape: torch.Size([512, 10])
    .variance.shape: torch.Size([10])
    .sum_grad_squared.shape: torch.Size([10])
    .batch_l2.shape: torch.Size([512])

.. GENERATED FROM PYTHON SOURCE LINES 115-117

Second-order extensions
--------------------------

.. GENERATED FROM PYTHON SOURCE LINES 119-120

Diagonal of the generalized Gauss-Newton (GGN) and its Monte Carlo approximation

.. GENERATED FROM PYTHON SOURCE LINES 120-131
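
For this model, the exact GGN diagonal can be written down directly. With a
mean-reduced cross-entropy loss, the Hessian of the loss with respect to the
logits of sample n is ``(diag(p_n) - p_n p_n^T) / N``, where ``p_n`` are the
softmax probabilities, and the logits depend linearly on the weights. The
sketch below (plain PyTorch, small random data, independent of BackPACK) uses
this to compute the diagonal and compares it with the diagonal of the full
Hessian from ``torch.autograd.functional.hessian``. Because the logits are
linear in the parameters, GGN and Hessian coincide for this model, so we would
expect ``diag_ggn_exact`` and ``diag_h`` to agree here as well; in general this
no longer holds once nonlinear layers are involved.

.. code-block:: python

    import torch
    from torch.autograd.functional import hessian
    from torch.nn.functional import cross_entropy, softmax

    torch.manual_seed(0)
    N, D, C = 8, 20, 10  # small D keeps the full [C*D, C*D] Hessian cheap
    inputs = torch.rand(N, D)
    labels = torch.randint(0, C, (N,))
    weight = 0.1 * torch.randn(C, D)
    bias = torch.zeros(C)

    # Hessian of the mean cross-entropy w.r.t. the logits of sample n is
    # (diag(p_n) - p_n p_n^T) / N; only its diagonal is needed here
    probs = softmax(inputs @ weight.T + bias, dim=1)  # [N, C]
    h = (probs - probs**2) / N  # [N, C]

    # GGN = J^T H J with d logit_c / d weight[c, d] = x_d, which gives
    # diagonal entries sum_n h[n, c] * x[n, d]^2
    diag_ggn_w = h.T @ (inputs**2)  # [C, D]
    diag_ggn_b = h.sum(dim=0)  # [C]


    def loss_of_weight(w_flat):
        """Mean cross-entropy as a function of the flattened weight matrix."""
        return cross_entropy(inputs @ w_flat.reshape(C, D).T + bias, labels)


    # The logits are linear in the weights, so here the GGN equals the Hessian
    full_hessian = hessian(loss_of_weight, weight.flatten())  # [C*D, C*D]
    print(torch.allclose(full_hessian.diagonal().reshape(C, D), diag_ggn_w, atol=1e-6))
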

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(DiagGGNExact(), DiagGGNMC(mc_samples=1)):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".diag_ggn_mc.shape: ", param.diag_ggn_mc.shape)
        print(".diag_ggn_exact.shape: ", param.diag_ggn_exact.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .diag_ggn_mc.shape: torch.Size([10, 784])
    .diag_ggn_exact.shape: torch.Size([10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    .diag_ggn_mc.shape: torch.Size([10])
    .diag_ggn_exact.shape: torch.Size([10])

.. GENERATED FROM PYTHON SOURCE LINES 132-133

Per-sample diagonal of the generalized Gauss-Newton and its Monte Carlo
approximation

.. GENERATED FROM PYTHON SOURCE LINES 133-144

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(BatchDiagGGNExact(), BatchDiagGGNMC(mc_samples=1)):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".diag_ggn_mc_batch.shape: ", param.diag_ggn_mc_batch.shape)
        print(".diag_ggn_exact_batch.shape: ", param.diag_ggn_exact_batch.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .diag_ggn_mc_batch.shape: torch.Size([512, 10, 784])
    .diag_ggn_exact_batch.shape: torch.Size([512, 10, 784])
    1.bias
    .diag_ggn_mc_batch.shape: torch.Size([512, 10])
    .diag_ggn_exact_batch.shape: torch.Size([512, 10])

.. GENERATED FROM PYTHON SOURCE LINES 145-146

Kronecker-factored curvature approximations: KFAC, KFRA, and KFLR

.. GENERATED FROM PYTHON SOURCE LINES 146-158

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(KFAC(mc_samples=1), KFLR(), KFRA()):
        loss.backward()

    # each attribute holds a list of Kronecker factors
    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".kfac (shapes): ", [kfac.shape for kfac in param.kfac])
        print(".kflr (shapes): ", [kflr.shape for kflr in param.kflr])
        print(".kfra (shapes): ", [kfra.shape for kfra in param.kfra])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .kfac (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
    .kflr (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
    .kfra (shapes): [torch.Size([10, 10]), torch.Size([784, 784])]
    1.bias
    .grad.shape: torch.Size([10])
    .kfac (shapes): [torch.Size([10, 10])]
    .kflr (shapes): [torch.Size([10, 10])]
    .kfra (shapes): [torch.Size([10, 10])]

.. GENERATED FROM PYTHON SOURCE LINES 159-160

Diagonal Hessian and per-sample diagonal Hessian

.. GENERATED FROM PYTHON SOURCE LINES 160-171

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(DiagHessian(), BatchDiagHessian()):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".diag_h.shape: ", param.diag_h.shape)
        print(".diag_h_batch.shape: ", param.diag_h_batch.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .diag_h.shape: torch.Size([10, 784])
    .diag_h_batch.shape: torch.Size([512, 10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    .diag_h.shape: torch.Size([10])
    .diag_h_batch.shape: torch.Size([512, 10])

.. GENERATED FROM PYTHON SOURCE LINES 172-173

Matrix square root of the generalized Gauss-Newton and of its Monte Carlo
approximation. The leading dimension of the result has one entry per model
output (10 classes) for ``SqrtGGNExact`` and one per Monte Carlo sample for
``SqrtGGNMC``; the second dimension indexes the samples in the mini-batch.

.. GENERATED FROM PYTHON SOURCE LINES 173-184

.. code-block:: Python

    loss = lossfunc(model(X), y)
    with backpack(SqrtGGNExact(), SqrtGGNMC(mc_samples=1)):
        loss.backward()

    for name, param in model.named_parameters():
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print(".sqrt_ggn_exact.shape: ", param.sqrt_ggn_exact.shape)
        print(".sqrt_ggn_mc.shape: ", param.sqrt_ggn_mc.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    .sqrt_ggn_exact.shape: torch.Size([10, 512, 10, 784])
    .sqrt_ggn_mc.shape: torch.Size([1, 512, 10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    .sqrt_ggn_exact.shape: torch.Size([10, 512, 10])
    .sqrt_ggn_mc.shape: torch.Size([1, 512, 10])

.. GENERATED FROM PYTHON SOURCE LINES 185-187

Block-diagonal curvature products
---------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 189-195

Curvature-matrix product (``MP``) extensions provide functions that multiply
vectors with the block diagonal of a curvature matrix, one block per parameter.
Supported curvature matrices include

- the Hessian (:code:`HMP`)
- the generalized Gauss-Newton (:code:`GGNMP`)
- the positive-curvature Hessian (:code:`PCHMP`)

In the following code, ``PCHMP`` is used twice: its ``modify`` argument
(``"clip"`` or ``"abs"``) selects the strategy for removing negative curvature,
and ``savefield`` sets the attribute that stores the resulting function, so both
variants can coexist on each parameter.

.. GENERATED FROM PYTHON SOURCE LINES 195-206

.. code-block:: Python

    loss = lossfunc(model(X), y)

    with backpack(
        HMP(),
        GGNMP(),
        PCHMP(savefield="pchmp_clip", modify="clip"),
        PCHMP(savefield="pchmp_abs", modify="abs"),
    ):
        loss.backward()

.. GENERATED FROM PYTHON SOURCE LINES 207-208

Multiply a random vector with the curvature blocks. The leading dimension ``V``
of ``vec`` counts the vectors; here ``V = 1``.

.. GENERATED FROM PYTHON SOURCE LINES 208-221

.. code-block:: Python

    V = 1

    for name, param in model.named_parameters():
        vec = rand(V, *param.shape)
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print("vec.shape: ", vec.shape)
        print(".hmp(vec).shape: ", param.hmp(vec).shape)
        print(".ggnmp(vec).shape: ", param.ggnmp(vec).shape)
        print(".pchmp_clip(vec).shape: ", param.pchmp_clip(vec).shape)
        print(".pchmp_abs(vec).shape: ", param.pchmp_abs(vec).shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    vec.shape: torch.Size([1, 10, 784])
    .hmp(vec).shape: torch.Size([1, 10, 784])
    .ggnmp(vec).shape: torch.Size([1, 10, 784])
    .pchmp_clip(vec).shape: torch.Size([1, 10, 784])
    .pchmp_abs(vec).shape: torch.Size([1, 10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    vec.shape: torch.Size([1, 10])
    .hmp(vec).shape: torch.Size([1, 10])
    .ggnmp(vec).shape: torch.Size([1, 10])
    .pchmp_clip(vec).shape: torch.Size([1, 10])
    .pchmp_abs(vec).shape: torch.Size([1, 10])

.. GENERATED FROM PYTHON SOURCE LINES 222-223

Multiply a collection of three vectors (a matrix) with curvature blocks.

.. GENERATED FROM PYTHON SOURCE LINES 223-235

.. code-block:: Python

    V = 3

    for name, param in model.named_parameters():
        vec = rand(V, *param.shape)
        print(name)
        print(".grad.shape: ", param.grad.shape)
        print("vec.shape: ", vec.shape)
        print(".hmp(vec).shape: ", param.hmp(vec).shape)
        print(".ggnmp(vec).shape: ", param.ggnmp(vec).shape)
        print(".pchmp_clip(vec).shape: ", param.pchmp_clip(vec).shape)
        print(".pchmp_abs(vec).shape: ", param.pchmp_abs(vec).shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    1.weight
    .grad.shape: torch.Size([10, 784])
    vec.shape: torch.Size([3, 10, 784])
    .hmp(vec).shape: torch.Size([3, 10, 784])
    .ggnmp(vec).shape: torch.Size([3, 10, 784])
    .pchmp_clip(vec).shape: torch.Size([3, 10, 784])
    .pchmp_abs(vec).shape: torch.Size([3, 10, 784])
    1.bias
    .grad.shape: torch.Size([10])
    vec.shape: torch.Size([3, 10])
    .hmp(vec).shape: torch.Size([3, 10])
    .ggnmp(vec).shape: torch.Size([3, 10])
    .pchmp_clip(vec).shape: torch.Size([3, 10])
    .pchmp_abs(vec).shape: torch.Size([3, 10])

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.774 seconds)


.. _sphx_glr_download_basic_usage_example_all_in_one.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: example_all_in_one.ipynb <example_all_in_one.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: example_all_in_one.py <example_all_in_one.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: example_all_in_one.zip <example_all_in_one.zip>`


.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_