
Difference between torch.flatten() and nn.Flatten()


What are the differences between torch.flatten() and torch.nn.Flatten()?


Solution

  • Flattening is available in three forms in PyTorch:

    • As a tensor method (OOP style), torch.Tensor.flatten, applied directly on a tensor: x.flatten().

    • As a function (functional form), torch.flatten, applied as torch.flatten(x).

    • As a module (an nn.Module layer), nn.Flatten(), generally used inside a model definition.

    All three compute the same operation and share the same underlying implementation. The only difference is that nn.Flatten has start_dim set to 1 by default, so it leaves the first axis (usually the batch axis) intact. The other two default to start_dim=0 and end_dim=-1, i.e. they flatten the entire tensor when no arguments are given.
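A minimal sketch of the default behaviors described above. The input shape (a batch of 32 single-channel 28x28 inputs) and the small convolutional model are arbitrary choices for illustration, not part of the original answer. It shows that the method and function flatten everything by default, that nn.Flatten keeps the batch axis, and that passing start_dim=1 to the function or method reproduces nn.Flatten's default.

```python
import torch
import torch.nn as nn

# Illustrative input (assumed shape): 32 samples, each 1x28x28
x = torch.randn(32, 1, 28, 28)

# Tensor method and functional form: default start_dim=0, end_dim=-1,
# so every axis (including the batch axis) is flattened
a = x.flatten()          # shape (25088,)
b = torch.flatten(x)     # shape (25088,)

# Module form: default start_dim=1, so the batch axis is preserved
flatten = nn.Flatten()
c = flatten(x)           # shape (32, 784)

print(a.shape, b.shape, c.shape)

# Passing start_dim=1 explicitly reproduces nn.Flatten's default
assert torch.equal(torch.flatten(x, start_dim=1), c)
assert torch.equal(x.flatten(1), c)

# Being a module, nn.Flatten can be chained inside nn.Sequential
# (hypothetical toy model, for illustration only)
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # -> (32, 8, 28, 28)
    nn.ReLU(),
    nn.Flatten(),                               # -> (32, 8 * 28 * 28)
    nn.Linear(8 * 28 * 28, 10),                 # -> (32, 10)
)
out = model(x)
print(out.shape)
```

This is why nn.Flatten is the usual choice inside nn.Sequential: only modules can be placed there, and its start_dim=1 default matches the common case of keeping one output row per sample.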