Tags: python, machine-learning, pytorch, autograd

Can my PyTorch forward function do additional operations?


Typically a forward function strings together a bunch of layers and returns the output of the last one. Can I do some additional processing after that last layer before returning? For example, some scalar multiplication and reshaping via .view?

I know that autograd somehow figures out the gradients, so I don't know whether my additional processing will screw that up. Thanks.


Solution

  • PyTorch tracks gradients through the computational graph of the tensors, not through the functions. As long as your tensors have requires_grad=True and every tensor computed from them has a grad_fn that is not None, you can do (almost) whatever you like and still be able to backprop.
    As long as you are using PyTorch's operations (e.g., those listed here and here), you should be okay.

    For more info see this.
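
One way to check whether an operation stayed inside the graph is to look at the result's requires_grad and grad_fn attributes. The snippet below is a minimal sketch rather than anything from the linked pages; the tensor names and shapes are arbitrary choices made for illustration.

```python
import torch

# Leaf tensor that autograd is asked to track (shape chosen arbitrarily).
x = torch.randn(4, 3, requires_grad=True)

# Scalar multiplication and .view are PyTorch operations, so they are recorded.
y = (2.0 * x).view(-1)
print(y.requires_grad, y.grad_fn is not None)  # True True

# detach() deliberately leaves the graph: the result carries no history.
z = y.detach() * 3.0
print(z.requires_grad, z.grad_fn)  # False None

# Backprop through the recorded operations reaches x.
y.sum().backward()
print(x.grad)  # every entry is 2.0, the derivative of sum(2 * x)
```

If a result unexpectedly shows requires_grad=False, some step between the input and that result probably left the graph, for example through detach() or a round trip through NumPy.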

    For example (taken from torchvision's VGG implementation):

    import torch
    import torch.nn as nn

    class VGG(nn.Module):
    
        def __init__(self, features, num_classes=1000, init_weights=True):
            super(VGG, self).__init__()
            #  ...
    
        def forward(self, x):
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)  # <-- what you were asking about
            x = self.classifier(x)
            return x
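
The same idea applies to the exact case in the question: a scalar multiplication and a reshape via .view after the last layer. The module below is a hypothetical toy (ScaledNet, its layer sizes, and the scale value are made up for illustration, not taken from torchvision), sketched only to show that the gradients still reach the layer's parameters.

```python
import torch
import torch.nn as nn

class ScaledNet(nn.Module):  # hypothetical example module
    def __init__(self, in_features=8, out_features=6, scale=0.5):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.scale = scale

    def forward(self, x):
        out = self.fc(x)            # the "last layer"
        out = self.scale * out      # extra processing: scalar multiplication
        return out.view(-1, 2, 3)   # extra processing: reshape via .view

model = ScaledNet()
x = torch.randn(4, 8)
y = model(x)
y.sum().backward()

print(y.shape)                     # torch.Size([4, 2, 3])
print(model.fc.weight.grad.shape)  # torch.Size([6, 8])
```

After backward(), model.fc.weight.grad is populated, which indicates that the multiplication and the reshape were both differentiated through.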
    

    A more complex example can be seen in torchvision's implementation of ResNet:

    class Bottleneck(nn.Module):
        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(Bottleneck, self).__init__()
            # ...
    
        def forward(self, x):
            identity = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            if self.downsample is not None:    # <-- conditional execution!
                identity = self.downsample(x)
    
            out += identity  # <-- in-place operation
            out = self.relu(out)
    
            return out
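
To see the conditional branch and the in-place addition working with autograd in something small enough to run, here is a stripped-down sketch. TinyResidual, its channel count, and the use_projection flag are hypothetical names invented for this illustration; the block only mimics the shape of the Bottleneck forward pass above.

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):  # hypothetical, loosely modeled on the block above
    def __init__(self, channels=4, use_projection=False):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        # Optional 1x1 convolution playing the role of `downsample`.
        self.projection = (
            nn.Conv2d(channels, channels, kernel_size=1) if use_projection else None
        )

    def forward(self, x):
        identity = x
        out = self.conv(x)
        if self.projection is not None:  # conditional execution
            identity = self.projection(x)
        out += identity                  # in-place addition
        return self.relu(out)

# The graph is rebuilt on every forward call, so both branches backprop fine.
for use_projection in (False, True):
    block = TinyResidual(use_projection=use_projection)
    x = torch.randn(2, 4, 5, 5, requires_grad=True)
    block(x).sum().backward()
    print(use_projection, x.grad.shape)
```

The "almost" in "almost whatever you like" matters most for in-place operations: if one overwrites a value that autograd needs for the backward pass, PyTorch raises a RuntimeError during backward() instead of silently computing wrong gradients.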