.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "use_cases/example_diag_ggn_optimizer.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_use_cases_example_diag_ggn_optimizer.py:


Diagonal Gauss-Newton second-order optimizer
================================================

A simple second-order optimizer built with BackPACK on the
`classic MNIST example from PyTorch `_.

The optimizer uses the diagonal of the GGN/Fisher matrix as a preconditioner,
with a constant damping parameter:

.. math::

    x_{t+1} = x_t - \gamma (G(x_t) + \lambda I)^{-1} g(x_t),

where

.. math::

    \begin{array}{ll}
        x_t: & \text{parameters of the model} \\
        g(x_t): & \text{gradient} \\
        G(x_t): & \text{diagonal of the Gauss-Newton/Fisher matrix at } x_t \\
        \lambda: & \text{damping parameter} \\
        \gamma: & \text{step size} \\
    \end{array}

.. GENERATED FROM PYTHON SOURCE LINES 30-31

Let's get the imports, configuration and some helper functions out of the way first.

.. GENERATED FROM PYTHON SOURCE LINES 31-71
.. code-block:: Python


    import matplotlib.pyplot as plt
    import torch

    from backpack import backpack, extend
    from backpack.extensions import DiagGGNMC
    from backpack.utils.examples import get_mnist_dataloader

    BATCH_SIZE = 128
    STEP_SIZE = 0.05
    DAMPING = 1.0
    MAX_ITER = 200
    PRINT_EVERY = 50
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)

    mnist_loader = get_mnist_dataloader(batch_size=BATCH_SIZE)

    model = torch.nn.Sequential(
        torch.nn.Conv2d(1, 20, 5, 1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
        torch.nn.Conv2d(20, 50, 5, 1),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 4 * 50, 500),
        torch.nn.ReLU(),
        torch.nn.Linear(500, 10),
    ).to(DEVICE)

    loss_function = torch.nn.CrossEntropyLoss().to(DEVICE)


    def get_accuracy(output, targets):
        """Helper function to compute the accuracy on a minibatch."""
        predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
        return predictions.eq(targets).float().mean().item()

.. GENERATED FROM PYTHON SOURCE LINES 72-89

Writing the optimizer
---------------------

To compute the update, we need access to the diagonal of the Gauss-Newton matrix.
BackPACK provides a Monte Carlo approximation of it in the ``diag_ggn_mc`` field
of each parameter, in addition to the ``grad`` field created by PyTorch.
We can use both to compute the update direction

.. math::

    (G(x_t) + \lambda I)^{-1} g(x_t).

Because :math:`G(x_t) + \lambda I` is diagonal, inverting it reduces to an
elementwise division, so for a parameter ``p`` the direction is
``p.grad / (p.diag_ggn_mc + damping)``.

.. GENERATED FROM PYTHON SOURCE LINES 89-102

.. code-block:: Python


    class DiagGGNOptimizer(torch.optim.Optimizer):
        def __init__(self, parameters, step_size, damping):
            super().__init__(parameters, dict(step_size=step_size, damping=damping))

        def step(self):
            for group in self.param_groups:
                for p in group["params"]:
                    # Elementwise preconditioning: (G + damping * I)^{-1} g
                    step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
                    # In-place update: p <- p - step_size * step_direction
                    p.data.add_(step_direction, alpha=-group["step_size"])
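
Since the optimizer only reads the ``grad`` and ``diag_ggn_mc`` attributes of each
parameter, its update rule can be checked without BackPACK. The following is a
minimal sketch under our own assumptions, not part of the original example: it
repeats ``DiagGGNOptimizer`` so the snippet runs on its own, uses a hypothetical
separable quadratic :math:`f(x) = \tfrac{1}{2} \sum_i h_i x_i^2 - \sum_i b_i x_i`
with made-up values ``h`` and ``b``, and fills in ``diag_ggn_mc`` by hand with the
exact curvature ``h`` instead of letting ``DiagGGNMC`` estimate it. The helper
``run`` is likewise hypothetical.

.. code-block:: Python

    import torch


    class DiagGGNOptimizer(torch.optim.Optimizer):
        """Same update as above: p <- p - step_size * grad / (diag_ggn_mc + damping)."""

        def __init__(self, parameters, step_size, damping):
            super().__init__(parameters, dict(step_size=step_size, damping=damping))

        def step(self):
            for group in self.param_groups:
                for p in group["params"]:
                    step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
                    p.data.add_(step_direction, alpha=-group["step_size"])


    # Hypothetical toy problem: f(x) = 0.5 * sum(h * x**2) - sum(b * x).
    # Its curvature is diag(h), which stands in for G(x_t); the minimizer is b / h.
    h = torch.tensor([1.0, 10.0, 100.0])
    b = torch.tensor([1.0, 1.0, 1.0])


    def run(step_size, damping, num_steps=1):
        """Start at x = 0, take `num_steps` optimizer steps, return the final x."""
        x = torch.nn.Parameter(torch.zeros(3))
        optimizer = DiagGGNOptimizer([x], step_size=step_size, damping=damping)
        for _ in range(num_steps):
            optimizer.zero_grad()
            loss = 0.5 * (h * x**2).sum() - (b * x).sum()
            loss.backward()
            # Set by hand here; in the MNIST example, BackPACK's DiagGGNMC
            # extension fills this field during loss.backward().
            x.diag_ggn_mc = h.clone()
            optimizer.step()
        return x.detach()


    print(run(step_size=1.0, damping=0.0))  # one undamped step reaches b / h
    print(run(step_size=1.0, damping=1.0))  # damped step: b / (h + 1)
    print(run(step_size=1.0, damping=1.0, num_steps=50))  # approaches b / h

With zero damping and unit step size, a single step lands exactly on the minimizer,
because the preconditioner is the exact curvature of this quadratic. A positive
damping shortens every step, most strongly along low-curvature coordinates (the
first entry is halved, the last barely changes), and repeated steps still converge
to ``b / h``.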

.. GENERATED FROM PYTHON SOURCE LINES 103-110

Running and plotting
--------------------

After ``extend``-ing the model and the loss function and creating the optimizer,
the only difference from a standard PyTorch training loop is the activation of the
``DiagGGNMC`` extension in a ``with backpack(DiagGGNMC()):`` block, so that BackPACK
stores a Monte Carlo approximation of the GGN diagonal in the ``diag_ggn_mc`` field
during the backward pass.

.. GENERATED FROM PYTHON SOURCE LINES 110-157

.. code-block:: Python


    extend(model)
    extend(loss_function)

    optimizer = DiagGGNOptimizer(model.parameters(), step_size=STEP_SIZE, damping=DAMPING)

    losses = []
    accuracies = []
    for batch_idx, (x, y) in enumerate(mnist_loader):
        optimizer.zero_grad()

        x, y = x.to(DEVICE), y.to(DEVICE)

        model.zero_grad()

        outputs = model(x)
        loss = loss_function(outputs, y)

        with backpack(DiagGGNMC()):
            loss.backward()

        optimizer.step()

        # Logging
        losses.append(loss.detach().item())
        accuracies.append(get_accuracy(outputs, y))

        if (batch_idx % PRINT_EVERY) == 0:
            print(
                "Iteration %3.d/%3.d " % (batch_idx, MAX_ITER)
                + "Minibatch Loss %.3f  " % losses[-1]
                + "Accuracy %.3f" % accuracies[-1]
            )

        if MAX_ITER is not None and batch_idx > MAX_ITER:
            break

    fig = plt.figure()
    axes = [fig.add_subplot(1, 2, 1), fig.add_subplot(1, 2, 2)]

    axes[0].plot(losses)
    axes[0].set_title("Loss")
    axes[0].set_xlabel("Iteration")

    axes[1].plot(accuracies)
    axes[1].set_title("Accuracy")
    axes[1].set_xlabel("Iteration")


.. image-sg:: /use_cases/images/sphx_glr_example_diag_ggn_optimizer_001.png
   :alt: Loss, Accuracy
   :srcset: /use_cases/images/sphx_glr_example_diag_ggn_optimizer_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/docs/checkouts/readthedocs.org/user_builds/backpack/checkouts/master/docs_src/examples/use_cases/example_diag_ggn_optimizer.py:129: UserWarning: Full backward hook is firing when gradients are computed with respect to module outputs since no inputs require gradients.
    See https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook for more details.
      loss.backward()
    Iteration 0/200 Minibatch Loss 2.313 Accuracy 0.078
    Iteration 50/200 Minibatch Loss 0.585 Accuracy 0.844
    Iteration 100/200 Minibatch Loss 0.359 Accuracy 0.883
    Iteration 150/200 Minibatch Loss 0.240 Accuracy 0.898
    Iteration 200/200 Minibatch Loss 0.252 Accuracy 0.938

    Text(0.5, 23.52222222222222, 'Iteration')




.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 13.333 seconds)