deep-learning, pytorch, neural-network

What is the difference between register_forward_hook and register_module_forward_hook in PyTorch?


As the title suggests, I am trying to understand how the functionality of these two forward-hook functions differs in PyTorch. I see that register_module_forward_hook adds a global state, which I assume means there is a single hook function applied to all modules. Is that so, or how does its functionality differ from the more commonly used register_forward_hook?

I ultimately want to compute the same statistical information from all of the layers of a given network, so the function to be used as a hook is the same across all layers. Is the latter a better choice?

I haven't tried either one yet, since I am first trying to figure out which is better suited to my case.


Solution

  • I was trying to figure out the same thing and found your question when Googling for it.

    From some digging:

    • register_forward_hook was added in this PR 7 years ago
    • register_module_forward_hook was added 3 years ago in this PR

    It appears the former needs to be registered on a per-module basis, while the latter is a global hook you set once and it runs for every module (a per-module contrast is sketched right after this snippet):

     # fw_hook here stands for your own hook function; the leading `1` is just
     # an extra argument bound via the lambda.
     test_fwd = nn.modules.module.register_module_forward_hook(lambda *args: fw_hook(1, *args))
    
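    For contrast, here is a minimal sketch of the per-module equivalent (assuming the same fw_hook and a model instance): register_forward_hook is called on each submodule separately and returns one handle per module, which you have to remove yourself.

     handles = [m.register_forward_hook(lambda *args: fw_hook(1, *args))
                for name, m in model.named_modules() if name]  # skip the root module
     # ... run the forward pass ...
     for handle in handles:
         handle.remove()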

    Looking at the blame for register_module_forward_hook shows this relevant issue from 3 months ago with more details.

    It sounds like the latter is the better choice for your case, particularly considering the comments on the latest commit, which make it compatible with context managers.

    For example, you can use it to compute per-example activation norms on every layer with a context manager like this:

    @contextmanager
    def module_hook(hook: Callable):
        # Install `hook` globally so it runs on every module's forward pass,
        # and remove it afterwards even if the forward pass raises.
        handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
        try:
            yield
        finally:
            handle.remove()
    
    def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
        # Store the per-example squared norm of the layer's input on the layer itself.
        A = inputs[0].detach()
        layer.norms2 = (A * A).sum(dim=1)
    
    with module_hook(compute_norms):
        outputs = model(data)
    
    print("layer", "norms squared")
    for name, layer in model.named_modules():
        if not name:  # skip the root module (the container itself)
            continue
        print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
    
    

    (Screenshot: the per-layer squared norms printed by the snippet above.)

    Full code from Colab:

    from contextlib import contextmanager
    from typing import Callable, Tuple
    
    import torch
    import torch.nn as nn
    
    import numpy as np
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    data = torch.tensor([[1., 0.], [1., 1.]]).to(device)
    bs = data.shape[0]  # batch size
    
    def simple_model(d, num_layers):
        """Creates simple linear neural network initialized to 2*identity"""
        layers = []
        for i in range(num_layers):
            layer = nn.Linear(d, d, bias=False)
            layer.weight.data.copy_(2 * torch.eye(d))
            layers.append(layer)
        return torch.nn.Sequential(*layers)
    
    
    def compute_norms(layer: nn.Module, inputs: Tuple[torch.Tensor], _output: torch.Tensor):
        assert len(inputs) == 1, "multi-input layer??"
        # Per-example squared norm of the layer's input, stored on the layer itself.
        A = inputs[0].detach()
        layer.norms2 = (A * A).sum(dim=1)
    
    model = simple_model(2, 3).to(device)
    
    @contextmanager
    def module_hook(hook: Callable):
        # Global hook: runs on every module's forward pass; removed on exit
        # even if the forward pass raises.
        handle = nn.modules.module.register_module_forward_hook(hook, always_call=True)
        try:
            yield
        finally:
            handle.remove()
    
    with module_hook(compute_norms):
        outputs = model(data)
    
    np.testing.assert_allclose(model[0].norms2.cpu(), [1, 2])
    np.testing.assert_allclose(model[1].norms2.cpu(), [4, 8])
    np.testing.assert_allclose(model[2].norms2.cpu(), [16, 32])
    
    print(f"{'layer':20s}: {'norms squared'}")
    for name, layer in model.named_modules():
        if not name:  # skip the root Sequential itself
            continue
        print(f"{name:20s}: {layer.norms2.cpu().numpy()}")
    
    # Sanity check: the context manager removed the global hook on exit.
    assert not torch.nn.modules.module._global_forward_hooks, "Some hooks remain"
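
    For comparison, the same statistic could be collected with the per-module register_forward_hook (a sketch reusing model, data and compute_norms from above). The result should be identical; the global hook just saves you from tracking one handle per module.

    handles = [layer.register_forward_hook(compute_norms)
               for name, layer in model.named_modules() if name]
    try:
        outputs = model(data)
    finally:
        # Each per-module registration returns its own handle; remove them individually.
        for handle in handles:
            handle.remove()

    np.testing.assert_allclose(model[0].norms2.cpu(), [1, 2])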