What are the differences between `torch.flatten()` and `torch.nn.Flatten()`?
Flattening is available in three forms in PyTorch:

1. As a tensor method (OOP style), `torch.Tensor.flatten`, applied directly on a tensor: `x.flatten()`.
2. As a function (functional form), `torch.flatten`, applied as `torch.flatten(x)`.
3. As a module (a layer, `nn.Module`), `nn.Flatten()`, generally used inside a model definition (see the sketch after this list).
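A minimal sketch of the module form in a model definition; the layer sizes and input shape below are illustrative assumptions, not taken from the question:

```python
import torch
import torch.nn as nn

# Hypothetical model: conv layer -> flatten -> linear classifier
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # (N, 3, 28, 28) -> (N, 16, 28, 28)
    nn.ReLU(),
    nn.Flatten(),                                # flattens everything except the batch dim
    nn.Linear(16 * 28 * 28, 10),                 # assumes 28x28 spatial input
)

x = torch.randn(32, 3, 28, 28)
print(model(x).shape)  # torch.Size([32, 10])
```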
All three share the same underlying implementation. The only difference is that `nn.Flatten` has `start_dim` set to `1` by default, so it avoids flattening the first axis (usually the batch axis), whereas the other two flatten the entire tensor, from `axis=0` to `axis=-1`, when called with no arguments.
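A quick comparison of the default behaviours; the example tensor shape is an assumption chosen for illustration:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 3, 28, 28)  # e.g. a batch of 32 RGB 28x28 images

# Tensor method and functional form flatten the whole tensor by default
print(x.flatten().shape)         # torch.Size([75264])  (32*3*28*28)
print(torch.flatten(x).shape)    # torch.Size([75264])

# nn.Flatten() keeps the batch dimension by default (start_dim=1)
print(nn.Flatten()(x).shape)     # torch.Size([32, 2352])  (3*28*28)

# Passing start_dim=1 explicitly makes the other forms behave the same way
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([32, 2352])
```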