Recurrent networks

There are two different approaches to using BackPACK with RNNs.

  1. Custom RNN with BackPACK custom modules: Build your RNN from custom modules provided by BackPACK instead of writing a custom forward pass. This approach is useful if you want to understand how BackPACK handles RNNs, or if you find a container module that implicitly defines the forward pass more elegant than coding up the forward pass yourself.

  2. RNN with BackPACK’s converter: Automatically convert your model into a BackPACK-compatible architecture.

Note

RNNs are still an experimental feature. Always double-check your results, as done in this example! Open an issue if you encounter a bug to help us improve the support.

Not all extensions support RNNs (yet). Please create a feature request in the repository if the extension you need is not supported.

Let’s get the imports out of the way.

from pkg_resources import packaging
from torch import (
    _C,
    allclose,
    cat,
    device,
    int32,
    linspace,
    manual_seed,
    nn,
    randint,
    zeros_like,
)

from backpack import backpack, extend
from backpack.custom_module.graph_utils import BackpackTracer
from backpack.custom_module.permute import Permute
from backpack.custom_module.reduce_tuple import ReduceTuple
from backpack.extensions import BatchGrad, DiagGGNExact
from backpack.utils import TORCH_VERSION
from backpack.utils.examples import autograd_diag_ggn_exact

manual_seed(0)
DEVICE = device("cpu")  # Verification via autograd only works on CPU

Note

Due to PyTorch issue #99413, we have to disable MKLDNN for PyTorch 2.0.1 to get the double-backward through LSTMs working.

if TORCH_VERSION == packaging.version.parse("2.0.1"):
    _C._set_mkldnn_enabled(False)
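If you would rather not switch MKLDNN off for the whole script, PyTorch also offers the context manager torch.backends.mkldnn.flags, which restores the previous setting when the block exits. The sketch below only demonstrates this scoping behavior; we assume, without having verified it against #99413, that running both the forward and the double-backward pass inside the block has the same effect as the global switch used above.

```python
import torch

# Global MKLDNN switch before entering the context
before = torch.backends.mkldnn.enabled

# torch.backends.mkldnn.flags sets the switch on entry and restores it on exit
with torch.backends.mkldnn.flags(enabled=False):
    inside = torch.backends.mkldnn.enabled
    # Assumption: forward and double-backward passes through LSTMs would go here

after = torch.backends.mkldnn.enabled
print(before, inside, after)
```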

For this demo, we will use the Tolstoi Char RNN from DeepOBS. This network is trained on Leo Tolstoi’s War and Peace and learns to predict the next character.
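One detail of the model below is that forward() returns logits in [N, C, T] layout instead of the [N, T, C] layout produced by nn.Linear, because nn.CrossEntropyLoss expects the class dimension second when the targets have shape [N, T]. The following sketch uses random logits in place of the model, with the demo's sizes, and checks that the loss on permuted logits equals the loss over N * T flattened predictions.

```python
import torch
from torch import nn

torch.manual_seed(0)
N, T, C = 8, 15, 25  # batch size, sequence length, vocabulary size (as in the demo)

logits_ntc = torch.randn(N, T, C)      # layout produced by nn.Linear: [N, T, C]
targets = torch.randint(0, C, (N, T))  # one class index per time step

loss_fn = nn.CrossEntropyLoss()

# Class dimension second: logits [N, C, T] with targets [N, T]
loss_permuted = loss_fn(logits_ntc.permute(0, 2, 1), targets)

# Equivalent view: N * T independent predictions, averaged
loss_flat = loss_fn(logits_ntc.reshape(N * T, C), targets.reshape(N * T))
print(loss_permuted.item(), loss_flat.item())
```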

class TolstoiCharRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.batch_size = 8
        self.hidden_dim = 64
        self.num_layers = 2
        self.seq_len = 15
        self.vocab_size = 25

        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size, embedding_dim=self.hidden_dim
        )
        self.dropout = nn.Dropout(p=0.2)
        self.lstm = nn.LSTM(
            input_size=self.hidden_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            dropout=0.36,
            batch_first=True,
        )
        # deactivate redundant bias
        self.lstm.bias_ih_l0.data = zeros_like(self.lstm.bias_ih_l0)
        self.lstm.bias_ih_l1.data = zeros_like(self.lstm.bias_ih_l1)
        self.lstm.bias_ih_l0.requires_grad = False
        self.lstm.bias_ih_l1.requires_grad = False
        self.dense = nn.Linear(
            in_features=self.hidden_dim, out_features=self.vocab_size
        )

    def forward(self, x):
        x = self.embedding(x)
        x = self.dropout(x)
        x, _ = self.lstm(x)  # last return values are hidden states
        x = self.dropout(x)
        output = self.dense(x)
        output = output.permute(0, 2, 1)  # [N, T, C] → [N, C, T]
        return output

    def input_target_fn(self):
        input = randint(0, self.vocab_size, (self.batch_size, self.seq_len))
        # target is the input shifted by 1 in time axis
        target = cat(
            [
                randint(0, self.vocab_size, (self.batch_size, 1)),
                input[:, :-1],
            ],
            dim=1,
        )
        return input.to(DEVICE), target.to(DEVICE)

    def loss_fn(self) -> nn.Module:
        return nn.CrossEntropyLoss().to(DEVICE)


manual_seed(1)
tolstoi_char_rnn = TolstoiCharRNN().to(DEVICE).eval()
loss_function = extend(tolstoi_char_rnn.loss_fn())
x, y = tolstoi_char_rnn.input_target_fn()

Note that, for simplicity, we feed synthetic data to the network instead of the real dataset. We also put the network in evaluation mode. This disables all dropout, including the dropout between the LSTM layers, so the forward pass is deterministic and our results can be double-checked with torch.autograd.
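To see what evaluation mode changes, here is a small check on a standalone nn.Dropout layer with random inputs (an illustration, not part of the demo). In training mode the layer zeroes random entries and rescales the surviving ones by 1 / (1 - p); in evaluation mode it returns its input unchanged.

```python
import torch
from torch import nn

torch.manual_seed(0)
p = 0.2
dropout = nn.Dropout(p=p)
x = torch.randn(8, 15, 64)

dropout.train()
y_train = dropout(x)  # random zeros, survivors scaled by 1 / (1 - p)

dropout.eval()
y_eval = dropout(x)   # identity map

kept = y_train != 0
print(torch.equal(y_eval, x), kept.float().mean().item())
```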

Custom RNN with BackPACK custom modules

Second-order extensions only work if every node in the computation graph is an nn.Module that BackPACK can extend. The TolstoiCharRNN above does not satisfy this condition: it uses a multi-layer torch.nn.LSTM, and its forward() method calls getitem() (to unpack the LSTM's output tuple) and permute() as plain functions rather than modules.
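To make these non-module nodes visible, the following sketch traces a hypothetical miniature of TolstoiCharRNN (TinyCharRNN, with one LSTM and one linear layer) using torch.fx.symbolic_trace. For simplicity it indexes the LSTM's output tuple instead of unpacking it; torch.fx records this as a getitem call_function node, and permute shows up as a call_method node. Neither is a module that BackPACK could hook into.

```python
import torch
from torch import fx, nn


class TinyCharRNN(nn.Module):
    """Hypothetical miniature of TolstoiCharRNN's forward pass."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
        self.dense = nn.Linear(6, 3)

    def forward(self, x):
        x = self.lstm(x)[0]        # tuple indexing -> getitem function call
        x = self.dense(x)
        return x.permute(0, 2, 1)  # tensor method call, not a module


graph = fx.symbolic_trace(TinyCharRNN()).graph
for node in graph.nodes:
    print(f"{node.op:<14} {node.target}")
```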

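To build an RNN without writing a custom forward pass, BackPACK offers the following custom modules:

  1. ReduceTuple: selects one element of a tuple, e.g. the output tensor from the (output, (h_n, c_n)) tuple returned by nn.LSTM (index=0)

  2. Permute: applies permute() to its input as a module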
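Conceptually, each of these modules wraps a function call in an nn.Module. The sketch below uses two hypothetical stand-ins (SelectFromTuple and PermuteDims, not BackPACK's classes) to show how such wrappers let an nn.Sequential express the LSTM → Linear → permute pipeline without a custom forward().

```python
import torch
from torch import nn


class SelectFromTuple(nn.Module):
    """Hypothetical stand-in for ReduceTuple: picks one entry of a tuple."""

    def __init__(self, index: int = 0):
        super().__init__()
        self.index = index

    def forward(self, inputs):
        return inputs[self.index]


class PermuteDims(nn.Module):
    """Hypothetical stand-in for Permute: reorders tensor dimensions."""

    def __init__(self, *dims: int):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
linear = nn.Linear(6, 3)
model = nn.Sequential(lstm, SelectFromTuple(0), linear, PermuteDims(0, 2, 1))

x = torch.randn(2, 5, 4)
out = model(x)
ref = linear(lstm(x)[0]).permute(0, 2, 1)  # same computation, written by hand
print(out.shape)
```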

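With these modules, and by replacing the two-layer LSTM with two stacked single-layer LSTMs separated by a Dropout(p=0.36) layer (mirroring the LSTM's inter-layer dropout), we can build the same RNN as a container that implicitly defines the forward pass: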

manual_seed(1)  # same seed as used to initialize `tolstoi_char_rnn`
tolstoi_char_rnn_custom = nn.Sequential(
    nn.Embedding(tolstoi_char_rnn.vocab_size, tolstoi_char_rnn.hidden_dim),
    nn.Dropout(p=0.2),
    nn.LSTM(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.hidden_dim, batch_first=True),
    ReduceTuple(index=0),
    nn.Dropout(p=0.36),
    nn.LSTM(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.hidden_dim, batch_first=True),
    ReduceTuple(index=0),
    nn.Dropout(p=0.2),
    nn.Linear(tolstoi_char_rnn.hidden_dim, tolstoi_char_rnn.vocab_size),
    Permute(0, 2, 1),
)
tolstoi_char_rnn_custom.eval().to(DEVICE)
Sequential(
  (0): Embedding(25, 64)
  (1): Dropout(p=0.2, inplace=False)
  (2): LSTM(64, 64, batch_first=True)
  (3): ReduceTuple()
  (4): Dropout(p=0.36, inplace=False)
  (5): LSTM(64, 64, batch_first=True)
  (6): ReduceTuple()
  (7): Dropout(p=0.2, inplace=False)
  (8): Linear(in_features=64, out_features=25, bias=True)
  (9): Permute()
)

Next, we deactivate the redundant input-hidden biases (as in TolstoiCharRNN) and check that both models produce the same forward pass.

for name, p in tolstoi_char_rnn_custom.named_parameters():
    if "bias_ih_l" in name:
        # deactivate redundant bias
        p.data = zeros_like(p.data)
        p.requires_grad = False

match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x))
print(f"Forward pass of custom model matches TolstoiCharRNN? {match}")

if not match:
    raise AssertionError("Forward passes don't match.")
Forward pass of custom model matches TolstoiCharRNN? True
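Why are the input-hidden biases redundant? As documented for nn.LSTM, every gate adds both b_ih and b_hh to W_ih x_t + W_hh h_{t-1}, so only their sum matters. The sketch below, on a small randomly initialized LSTM, folds bias_ih_l0 into bias_hh_l0 and checks that the output is unchanged. The demo zeroes b_ih instead of folding it; since it does so in both models, their forward passes still agree.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=6, batch_first=True)
x = torch.randn(2, 5, 4)
out_ref, _ = lstm(x)

with torch.no_grad():
    lstm.bias_hh_l0.add_(lstm.bias_ih_l0)  # move b_ih into b_hh
    lstm.bias_ih_l0.zero_()                # b_ih is now inactive

out_folded, _ = lstm(x)
print(torch.allclose(out_ref, out_folded, atol=1e-6))
```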

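Now we can extend the custom model (the loss function is already extended) and compute BackPACK extensions, here individual gradients (BatchGrad) and the exact GGN diagonal (DiagGGNExact).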

tolstoi_char_rnn_custom = extend(tolstoi_char_rnn_custom)
loss = loss_function(tolstoi_char_rnn_custom(x), y)

with backpack(BatchGrad(), DiagGGNExact()):
    loss.backward()

for name, param in tolstoi_char_rnn_custom.named_parameters():
    if param.requires_grad:
        print(
            name,
            param.shape,
            param.grad_batch.shape,
            param.diag_ggn_exact.shape,
        )
0.weight torch.Size([25, 64]) torch.Size([8, 25, 64]) torch.Size([25, 64])
2.weight_ih_l0 torch.Size([256, 64]) torch.Size([8, 256, 64]) torch.Size([256, 64])
2.weight_hh_l0 torch.Size([256, 64]) torch.Size([8, 256, 64]) torch.Size([256, 64])
2.bias_hh_l0 torch.Size([256]) torch.Size([8, 256]) torch.Size([256])
5.weight_ih_l0 torch.Size([256, 64]) torch.Size([8, 256, 64]) torch.Size([256, 64])
5.weight_hh_l0 torch.Size([256, 64]) torch.Size([8, 256, 64]) torch.Size([256, 64])
5.bias_hh_l0 torch.Size([256]) torch.Size([8, 256]) torch.Size([256])
8.weight torch.Size([25, 64]) torch.Size([8, 25, 64]) torch.Size([25, 64])
8.bias torch.Size([25]) torch.Size([8, 25]) torch.Size([25])
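The shapes above also determine the total number of trainable parameters, and thereby which diagonal entries the comparison below samples. Each LSTM weight has 4 * 64 = 256 rows because PyTorch stacks the four gates. The following plain-Python sketch recomputes the count and the evenly spaced indices; for these numbers, truncating the evenly spaced positions reproduces the indices printed by the comparison.

```python
import math

# Trainable parameter shapes printed above (the frozen bias_ih_l0 is excluded)
shapes = {
    "0.weight": (25, 64),
    "2.weight_ih_l0": (256, 64),  # 256 = 4 gates * hidden size 64
    "2.weight_hh_l0": (256, 64),
    "2.bias_hh_l0": (256,),
    "5.weight_ih_l0": (256, 64),
    "5.weight_hh_l0": (256, 64),
    "5.bias_hh_l0": (256,),
    "8.weight": (25, 64),
    "8.bias": (25,),
}
num_params = sum(math.prod(shape) for shape in shapes.values())

# Evenly spaced positions in [0, num_params - 1], truncated to integers
num_to_compare = 10
indices = [k * (num_params - 1) // (num_to_compare - 1) for k in range(num_to_compare)]
print(num_params, indices)
```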

Let’s compare BackPACK’s GGN diagonal with torch.autograd:

Note

Computing the full GGN diagonal with PyTorch’s built-in automatic differentiation can be slow, depending on the number of parameters. To reduce run time, we only compare some elements of the diagonal.

trainable_params = [p for p in tolstoi_char_rnn_custom.parameters() if p.requires_grad]

diag_ggn_exact_vector = cat([p.diag_ggn_exact.flatten() for p in trainable_params])

num_params = sum(p.numel() for p in trainable_params)
num_to_compare = 10
idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)

diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
    x, y, tolstoi_char_rnn_custom, loss_function, idx=idx_to_compare
)

print("Do the exact GGN diagonals match?")
for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
    match = allclose(element, diag_ggn_exact_vector[idx])
    print(
        f"Diagonal entry {idx:>8}: {match}:"
        + f"\t{element:.5e}, {diag_ggn_exact_vector[idx]:.5e}"
    )
    if not match:
        raise AssertionError("Exact GGN diagonals don't match!")
Do the exact GGN diagonals match?
Diagonal entry        0: True:  3.49859e-07, 3.49858e-07
Diagonal entry     7696: True:  1.52881e-07, 1.52881e-07
Diagonal entry    15393: True:  5.09857e-07, 5.09857e-07
Diagonal entry    23090: True:  6.00690e-09, 6.00689e-09
Diagonal entry    30787: True:  2.20773e-08, 2.20773e-08
Diagonal entry    38484: True:  1.05477e-08, 1.05477e-08
Diagonal entry    46181: True:  3.00283e-05, 3.00282e-05
Diagonal entry    53878: True:  3.29222e-09, 3.29222e-09
Diagonal entry    61575: True:  2.02606e-06, 2.02606e-06
Diagonal entry    69272: True:  4.29104e-02, 4.29104e-02
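For intuition about the reference values: the exact GGN is G = J^T H J, where J is the Jacobian of the model output with respect to the parameters and H is the Hessian of the loss with respect to the model output. The sketch below materializes the full P-by-P matrix for a tiny bias-free linear classifier (a hypothetical stand-in, not the RNN) with torch.autograd.functional; building such a matrix is what makes autograd-based checks expensive for larger models. It is not necessarily how autograd_diag_ggn_exact is implemented. Because the toy model is linear in its parameters, its GGN coincides with the Hessian, which provides a built-in sanity check.

```python
import torch
from torch.autograd.functional import hessian, jacobian
from torch.nn.functional import cross_entropy

torch.manual_seed(0)
N, D, C = 4, 3, 5           # batch size, input features, classes
X = torch.randn(N, D)
y = torch.randint(0, C, (N,))
w = torch.randn(C * D)      # flattened weight of a bias-free linear model


def model(w_flat):
    """Logits of the linear model, flattened to shape [N * C]."""
    return (X @ w_flat.reshape(C, D).T).flatten()


def loss_of_logits(z_flat):
    """Mean cross-entropy as a function of the flattened logits."""
    return cross_entropy(z_flat.reshape(N, C), y)


J = jacobian(model, w)                   # [N * C, C * D]
H_z = hessian(loss_of_logits, model(w))  # [N * C, N * C], w.r.t. the logits
ggn = J.T @ H_z @ J                      # full GGN matrix, [C * D, C * D]
ggn_diag = ggn.diagonal()

# Linear model: the GGN equals the Hessian of the loss w.r.t. the parameters
hess_diag = hessian(lambda v: loss_of_logits(model(v)), w).diagonal()
print(ggn_diag[:5])
```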

RNN with BackPACK’s converter

If you do not want to rebuild your RNN from custom modules but, for instance, want to use the Tolstoi Char RNN directly, BackPACK offers a converter. It analyzes the model and tries to turn it into a compatible architecture. The result is a torch.fx.GraphModule that consists exclusively of modules.
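The following sketch illustrates the idea behind such a conversion on a hypothetical two-line model, using only public torch.fx calls. It traces the model, replaces each permute() method call with a submodule (PermuteModule, a hypothetical stand-in for BackPACK's Permute), and recompiles. This is not BackPACK's converter, which handles many more cases, such as splitting multi-layer LSTMs.

```python
import torch
from torch import fx, nn


class PermuteModule(nn.Module):
    """Hypothetical stand-in for BackPACK's Permute module."""

    def __init__(self, *dims: int):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 3)

    def forward(self, x):
        return self.dense(x).permute(0, 2, 1)  # traced as a call_method node


def convert_permutes(module: nn.Module) -> fx.GraphModule:
    """Replace every tensor.permute(...) call with a PermuteModule submodule."""
    gm = fx.symbolic_trace(module)
    for i, node in enumerate(list(gm.graph.nodes)):
        if node.op == "call_method" and node.target == "permute":
            name = f"permute{i}"
            # args = (tensor, *dims) for calls like x.permute(0, 2, 1)
            gm.add_submodule(name, PermuteModule(*node.args[1:]))
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_module(name, args=(node.args[0],))
            node.replace_all_uses_with(new_node)
            gm.graph.erase_node(node)
    gm.recompile()
    return gm


net = Net()
converted = convert_permutes(net)
x = torch.randn(2, 5, 4)
print([node.op for node in converted.graph.nodes])
```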

Here, we demonstrate the converter on the Tolstoi Char RNN from above. Let’s convert it while calling extend on the model:

# use BackPACK's converter to extend the model (turned off by default)
tolstoi_char_rnn = extend(tolstoi_char_rnn, use_converter=True)

To understand what happened, we can inspect the model’s graph with the following helper function:

def print_table(module: nn.Module) -> None:
    """Prints a table of the module.

    Args:
        module: module to analyze
    """
    graph = BackpackTracer().trace(module)
    graph.print_tabular()


print_table(tolstoi_char_rnn)
opcode       name                 target               args                    kwargs
-----------  -------------------  -------------------  ----------------------  --------
placeholder  x                    x                    ()                      {}
call_module  embedding            embedding            (x,)                    {}
call_module  dropout0             dropout0             (embedding,)            {}
call_module  lstm_lstm_0          lstm.lstm_0          (dropout0,)             {}
call_module  lstm_reduce_tuple_0  lstm.reduce_tuple_0  (lstm_lstm_0,)          {}
call_module  lstm_dropout_0       lstm.dropout_0       (lstm_reduce_tuple_0,)  {}
call_module  lstm_lstm_1          lstm.lstm_1          (lstm_dropout_0,)       {}
call_module  reduce_tuple0        reduce_tuple0        (lstm_lstm_1,)          {}
call_module  reduce_tuple1        reduce_tuple1        (lstm_lstm_1,)          {}
call_module  dropout1             dropout1             (reduce_tuple0,)        {}
call_module  dense                dense                (dropout1,)             {}
call_module  permute0             permute0             (dense,)                {}
output       output               output               (permute0,)             {}

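Note that, apart from the input placeholder and the output node, the computation graph consists entirely of modules (indicated by call_module in the first table column). The converter split the two-layer LSTM into two single-layer LSTMs (lstm.lstm_0 and lstm.lstm_1) separated by a tuple reduction and a dropout module, and replaced the tuple unpacking and the permute() call with modules (reduce_tuple0, reduce_tuple1, permute0). Hence BackPACK’s hooks can backpropagate the additional information needed by its second-order extensions (first-order extensions work, too).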

First, let’s compare the forward pass with the custom model from the previous section to make sure the converter worked correctly:

match = allclose(tolstoi_char_rnn_custom(x), tolstoi_char_rnn(x))
print(f"Forward pass of extended TolstoiCharRNN matches custom model? {match}")

if not match:
    raise AssertionError("Forward passes don't match.")
Forward pass of extended TolstoiCharRNN matches custom model? True

Now let’s verify that second-order extensions (GGN diagonal) are working:

output = tolstoi_char_rnn(x)
loss = loss_function(output, y)

with backpack(DiagGGNExact()):
    loss.backward()

for name, parameter in tolstoi_char_rnn.named_parameters():
    if parameter.requires_grad:
        print(f"{name}'s diag_ggn_exact: {parameter.diag_ggn_exact.shape}")

diag_ggn_exact_vector = cat(
    [
        p.diag_ggn_exact.flatten()
        for p in tolstoi_char_rnn.parameters()
        if p.requires_grad
    ]
)
embedding.weight's diag_ggn_exact: torch.Size([25, 64])
lstm.lstm_0.weight_ih_l0's diag_ggn_exact: torch.Size([256, 64])
lstm.lstm_0.weight_hh_l0's diag_ggn_exact: torch.Size([256, 64])
lstm.lstm_0.bias_hh_l0's diag_ggn_exact: torch.Size([256])
lstm.lstm_1.weight_ih_l0's diag_ggn_exact: torch.Size([256, 64])
lstm.lstm_1.weight_hh_l0's diag_ggn_exact: torch.Size([256, 64])
lstm.lstm_1.bias_hh_l0's diag_ggn_exact: torch.Size([256])
dense.weight's diag_ggn_exact: torch.Size([25, 64])
dense.bias's diag_ggn_exact: torch.Size([25])

Finally, we compare BackPACK’s GGN diagonal with torch.autograd:

Note

Computing the full GGN diagonal with PyTorch’s built-in automatic differentiation can be slow, depending on the number of parameters. To reduce run time, we only compare some elements of the diagonal.

num_params = sum(p.numel() for p in tolstoi_char_rnn.parameters() if p.requires_grad)
num_to_compare = 10
idx_to_compare = linspace(0, num_params - 1, num_to_compare, device=DEVICE, dtype=int32)

diag_ggn_exact_to_compare = autograd_diag_ggn_exact(
    x, y, tolstoi_char_rnn, loss_function, idx=idx_to_compare
)

print("Do the exact GGN diagonals match?")
for idx, element in zip(idx_to_compare, diag_ggn_exact_to_compare):
    match = allclose(element, diag_ggn_exact_vector[idx])
    print(
        f"Diagonal entry {idx:>8}: {match}:"
        + f"\t{element:.5e}, {diag_ggn_exact_vector[idx]:.5e}"
    )
    if not match:
        raise AssertionError("Exact GGN diagonals don't match!")
Do the exact GGN diagonals match?
Diagonal entry        0: True:  3.49859e-07, 3.49858e-07
Diagonal entry     7696: True:  1.52881e-07, 1.52881e-07
Diagonal entry    15393: True:  5.09857e-07, 5.09857e-07
Diagonal entry    23090: True:  6.00690e-09, 6.00689e-09
Diagonal entry    30787: True:  2.20773e-08, 2.20773e-08
Diagonal entry    38484: True:  1.05477e-08, 1.05477e-08
Diagonal entry    46181: True:  3.00283e-05, 3.00282e-05
Diagonal entry    53878: True:  3.29222e-09, 3.29222e-09
Diagonal entry    61575: True:  2.02606e-06, 2.02606e-06
Diagonal entry    69272: True:  4.29104e-02, 4.29104e-02

Total running time of the script: (0 minutes 5.089 seconds)
