
Why do code authors re-use the variable x in a PyTorch forward method?


A typical forward method in a PyTorch model looks like this:

def forward(self, x):
    x = self.conv1(x)
    x = F.relu(F.max_pool2d(x, kernel_size=2))
    x = self.drop1(x)
    return x

It seems to be a universally used standard, but I have been able to get the code to work by creating new variables instead:

def forward(self, x):
    a = self.conv1(x)
    b = F.relu(F.max_pool2d(a, kernel_size=2))
    c = self.drop1(b)
    return c

I cannot find an actual explanation anywhere. Can someone explain why the reused-x version is preferred?


Solution

  • An object that is still referenced (e.g. through the name a) can't be freed by Python's memory management; its memory is reclaimed only once its reference count drops to zero.

    The version with the a, b, c variables can therefore lead to higher peak memory usage. By re-using the same name x, after a couple of lines there are no remaining references to the output of self.conv1(x), so its reference count drops to zero and its memory can be freed.

    Additionally, re-using the same variable can make it easier to quickly re-order operations or comment out some layers.

    Note that the peak memory usage will typically span two lines of code, not one. In your first example, memory for the output of F.relu is allocated and assigned to x before the memory holding the output of self.conv1 can be freed; but, unlike in the second example, it can be freed before the call to self.drop1. The sketch below illustrates the difference.
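
    If you want to observe the effect directly, here is a minimal sketch comparing peak GPU memory for the two styles. Everything in it is illustrative rather than taken from your model: the layer sizes, input shape, and the assumption that a CUDA device is available; the two forward variants simply mirror the two versions from your question.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Hypothetical layers, roughly mirroring the question's model.
    conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1).cuda()
    drop1 = nn.Dropout2d(p=0.5)

    def forward_reuse(x):
        # Re-using x: each assignment drops the last reference to the previous tensor.
        x = conv1(x)
        x = F.relu(F.max_pool2d(x, kernel_size=2))
        x = drop1(x)
        return x

    def forward_new_names(x):
        # New names: x, a and b stay referenced until the function returns.
        a = conv1(x)
        b = F.relu(F.max_pool2d(a, kernel_size=2))
        c = drop1(b)
        return c

    def peak_memory(fn):
        torch.cuda.reset_peak_memory_stats()
        with torch.no_grad():  # keep autograd's own buffering out of the comparison
            fn(torch.randn(8, 3, 256, 256, device="cuda"))
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    print("re-used x:", peak_memory(forward_reuse))
    print("a, b, c  :", peak_memory(forward_new_names))

    On a machine with a GPU, the a, b, c variant should report a higher peak, because the intermediate tensors remain referenced until the function returns.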