As the title suggests, I am trying to understand how these two functions work as forward hooks in PyTorch. I see that register_module_forward_hook adds a global state, which I assume means there is one function for all forward hooks. Is that so, or how does its functionality differ from the more commonly used register_forward_hook?
Ultimately I want to compute the same statistical information from all of the layers of a given network, so the function used as a hook is the same across all layers. Is the latter the better choice?
I haven't tried either yet since I am still trying to figure out which one is better suited to my case.
I was just trying to figure out the same question and found your question when Googling for it.
From some digging:
register_forward_hook was added in this PR 7 years ago; register_module_forward_hook was added 3 years ago in this PR.
It appears the former needs to be set on a per-module basis, while the latter is a global hook you set once to run for every module.
test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args))
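For contrast, here is a minimal sketch of the per-module route with register_forward_hook (assuming fw_hook and model are defined elsewhere; this is just an illustration, not code from the linked PRs):

handles = [m.register_forward_hook(lambda *args: fw_hook(1, *args))
           for name, m in model.named_modules() if name]
# ... run forward passes ...
for h in handles:
    h.remove()  # each per-module handle must be removed individually

With the global variant above, a single registration (and a single handle.remove()) covers every module.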
Looking at blame for register_module_forward_hook
shows this relevant issue with more details from 3 months ago.
It sounds like the latter is the better choice for your case, particularly given the comments from the latest commit, which make it compatible with context managers.
For example, you can use it to compute per-example activation norms on every layer with a context manager like this:
@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    yield
    handle.remove()

def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
    A = inputs[0].detach()
    layer.norms2 = (A * A).sum(dim=1)

with module_hook(compute_norms):
    outputs = model(data)

print("layer", "norms squared")
for name, layer in model.named_modules():
    if not name:
        continue
    print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
Full code from Colab:
from contextlib import contextmanager
from typing import Callable, Tuple

import torch
import torch.nn as nn
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'

data = torch.tensor([[1., 0.], [1., 1.]]).to(device)
bs = data.shape[0]  # batch size

def simple_model(d, num_layers):
    """Creates simple linear neural network initialized to 2*identity"""
    layers = []
    for i in range(num_layers):
        layer = nn.Linear(d, d, bias=False)
        layer.weight.data.copy_(2 * torch.eye(d))
        layers.append(layer)
    return torch.nn.Sequential(*layers)

norms = [torch.zeros(bs).to(device)]

# forward hook: stores the squared per-example norm of the layer's input on the layer itself
def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
    assert len(inputs) == 1, "multi-input layer??"
    A = inputs[0].detach()
    layer.norms2 = (A * A).sum(dim=1)

model = simple_model(2, 3).to(device)

# context manager that installs a global forward hook and removes it on exit
@contextmanager
def module_hook(hook: Callable):
    handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
    yield
    handle.remove()

with module_hook(compute_norms):
    outputs = model(data)

np.testing.assert_allclose(model[0].norms2.cpu(), [1, 2])
np.testing.assert_allclose(model[1].norms2.cpu(), [4, 8])
np.testing.assert_allclose(model[2].norms2.cpu(), [16, 32])

print(f"{'layer':20s}: {'norms squared'}")
for name, layer in model.named_modules():
    if not name:
        continue
    print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
    # print(name, layer.norms2)

assert not torch.nn.modules.module._global_forward_hooks, "Some hooks remain"