{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Batched Jacobians\n\nIn PyTorch, you can easily compute derivatives of a **scalar-valued** variable\n:code:`f` w.r.t. a variable :code:`param` by calling\n:code:`f.backward()`. This computes the Jacobian :code:`\u2202(f) / \u2202(param)`\nthat has shape :code:`[1, *param.shape]`.\n\nIf :code:`f` is a reduction of a **batched scalar** :code:`fs` of shape\n:code:`[N]`, then BackPACK can compute the individual gradients for\neach scalar with its :code:`BatchGrad` extension. This yields the Jacobian\n:code:`\u2202(fs) / \u2202(param)` of shape :code:`[N, *param.shape]`.\n\n**This example** demonstrates how to compute the Jacobian of a tensor-valued\nvariable :code:`fs`, here for the example of a **batched vector** of shape\n:code:`[N, C]`, whose Jacobian has shape :code:`[N, C, *param.shape]`.\n\n## Setup\n\nWe will use the batched vector-valued output of a simple MLP as the tensor\n:code:`fs` that should be differentiated w.r.t. the model parameters\n:code:`param_1, param_2, ...`. For :code:`param_i`, this leads to a Jacobian\n:code:`\u2202(fs) / \u2202(param_i)` of shape :code:`[N, C, *param_i.shape]`.\n\nLet's start by importing the required functionality and writing a setup\nfunction to create our synthetic data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import itertools\nfrom math import sqrt\nfrom typing import List, Tuple\n\nimport matplotlib.pyplot as plt\nfrom torch import Tensor, allclose, cat, manual_seed, rand, zeros, zeros_like\nfrom torch.autograd import grad\nfrom torch.nn import Linear, MSELoss, ReLU, Sequential\n\nfrom backpack import backpack, extend, extensions\n\n# architecture specifications\nN = 15\nD_in = 10\nD_hidden = 7\nC = 5\n\n\ndef setup() -> Tuple[Sequential, Tensor]:\n \"\"\"Create a simple MLP with ReLU activations and its synthetic input.\n\n Returns:\n A simple MLP and a tensor that can be fed to it.\n \"\"\"\n X = rand(N, D_in)\n model = Sequential(Linear(D_in, D_hidden), ReLU(), Linear(D_hidden, C))\n\n return model, X"
]
},
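{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the shapes from the introduction (a minimal sketch that only relies on the setup above), differentiating a scalar reduction of the model's output yields per-parameter gradients whose shapes match :code:`param.shape`:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"manual_seed(0)\nmodel, X = setup()\n\n# reduce the batched vector-valued output to a scalar and differentiate\nf = model(X).sum()\nf.backward()\n\nfor param in model.parameters():\n    assert param.grad.shape == param.shape"
]
},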
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## With autograd\n\nFirst, let's compute the Jacobians with PyTorch's :code:`autograd` to have a\nground truth to verify our results against.\n\nTo do that, we need to differentiate per component of :code:`fs`. This means\nthat we will differentiate multiple times through its graph, so we need to\nset :code:`retain_graph=True`.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"manual_seed(0)\nmodel, X = setup()\n\nfs = model(X)\nautograd_jacobians = [zeros(fs.shape + param.shape) for param in model.parameters()]\n\nfor n, c in itertools.product(range(N), range(C)):\n grads_n_c = grad(fs[n, c], model.parameters(), retain_graph=True)\n for param_idx, param_grad in enumerate(grads_n_c):\n autograd_jacobians[param_idx][n, c, :] = param_grad"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's visualize the Jacobians by flattening the dimensions stemming from\n:code:`fs` and from :code:`param_i`, and by concatenating them along the\nparameter dimension:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.figure()\nplt.title(r\"Batched Jacobian\")\nimage = plt.imshow(\n cat(\n [\n jac.flatten(end_dim=fs.dim() - 1).flatten(start_dim=1)\n for jac in autograd_jacobians\n ],\n dim=1,\n )\n)\nplt.colorbar(image, shrink=0.7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the following, we will compute the same Jacobian tensor lists with\nBackPACK. To compare our results, we will use the following helper function:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def compare_tensor_lists(\n    tensor_list1: List[Tensor], tensor_list2: List[Tensor]\n) -> None:\n    \"\"\"Check equality of two tensor lists.\n\n    Args:\n        tensor_list1: First tensor list.\n        tensor_list2: Second tensor list.\n\n    Raises:\n        ValueError: If the two tensor lists don't match.\n    \"\"\"\n    if len(tensor_list1) != len(tensor_list2):\n        raise ValueError(\"Tensor lists have different lengths.\")\n    for tensor1, tensor2 in zip(tensor_list1, tensor_list2):\n        if tensor1.shape != tensor2.shape:\n            raise ValueError(\"Tensors have different shapes.\")\n        if not allclose(tensor1, tensor2):\n            raise ValueError(\"Tensors have different values.\")\n    print(\"Both tensor lists match.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will present two approaches to compute such Jacobians with BackPACK.\n\nYou can imagine the first one as carrying out the for-loop over :code:`N` in\nparallel, and the second one as carrying out both for-loops over :code:`N, C`\nin parallel. The first approach relies on a first-order extension, the second\none on a second-order extension. This means that while the first approach\nworks on quite general graphs, the second one requires your graph to be\nfully BackPACK-compatible.\n\n## With BackPACK's :code:`BatchGrad`\n\nAs described in the introduction, BackPACK's :code:`BatchGrad` extension can\ncompute Jacobians of batched scalars. We can therefore compute the\nderivatives for :code:`fs[:, c]` in one iteration, parallelizing the Jacobian\ncomputation over the batch axis. For the full Jacobian, this requires\n:code:`C` backpropagations, hence we need to tell both :code:`autograd` and\nBackPACK to retain the graph.\n\nLet's do that in code and check the result:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"manual_seed(0)\nmodel, X = setup()\n\nmodel = extend(model)\n\nfs = model(X)\nbackpack_first_jacobians = [zeros(fs.shape + p.shape) for p in model.parameters()]\n\nfor c in range(C):\n with backpack(extensions.BatchGrad(), retain_graph=True):\n f = fs[:, c].sum()\n f.backward(retain_graph=True)\n\n for param_idx, param in enumerate(model.parameters()):\n backpack_first_jacobians[param_idx][:, c, :] = param.grad_batch\n\nprint(\"Comparing batched Jacobian from autograd with BackPACK (via BatchGrad):\")\ncompare_tensor_lists(autograd_jacobians, backpack_first_jacobians)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## With BackPACK's :code:`SqrtGGNExact`\n\nThe second approach uses BackPACK's :code:`SqrtGGNExact` second-order\nextension. It computes the matrix square root of the GGN/Fisher.\n\nThis approach exploits the fact that after feeding :code:`fs` through a\nsquare loss with :code:`reduction='sum'`, the GGN's square root is the\ndesired Jacobian up to a normalization factor of \u221a2 (to find out more,\nread Section 2 of [[Dangel, 2021]](https://arxiv.org/abs/2106.02624)) and a\ntransposition due to BackPACK's internals.\n\nThis way, we obtain the Jacobian in a single backward pass and don't have to\nretain the graph:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"manual_seed(0)\nmodel, X = setup()\n\nmodel = extend(model)\nloss_func = extend(MSELoss(reduction=\"sum\"))\n\nfs = model(X)\nfs_labels = zeros_like(fs) # can contain arbitrary values.\nloss = loss_func(fs, fs_labels)\n\nwith backpack(extensions.SqrtGGNExact()):\n loss.backward()\n\nbackpack_second_jacobians = [\n param.sqrt_ggn_exact.transpose(0, 1) / sqrt(2) for param in model.parameters()\n]\n\nprint(\"Comparing batched Jacobian from autograd with BackPACK (via SqrtGGNExact):\")\ncompare_tensor_lists(autograd_jacobians, backpack_second_jacobians)"
]
}
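,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a final check of the transposition mentioned above (a small sketch reusing the variables from the previous cell), :code:`param.sqrt_ggn_exact` carries the loss-output axis before the batch axis, i.e. has shape :code:`[C, N, *param.shape]`, while the transposed Jacobians match :code:`fs.shape + param.shape`:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"for param, jac in zip(model.parameters(), backpack_second_jacobians):\n    # BackPACK's internal layout: output dimension before the batch dimension\n    assert param.sqrt_ggn_exact.shape == (C, N) + tuple(param.shape)\n    # after transposing, the Jacobian has shape [N, C, *param.shape]\n    assert jac.shape == fs.shape + param.shape"
]
}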
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 0
}