
What's the correct way to express a residual block in a PyTorch forward function?

AFAIK there are two ways to express a ResNet block in PyTorch:

  • Copy the input at the beginning, modify the input along the way, and add the copy at the end.
  • Preserve the input at the beginning, create a new variable along the way, and add the input at the end.

This leads to two versions of the code:

# Version 1: keep the input under a second name, rebind x at each step
def forward(self, x):
    y = x
    x = self.conv1(x)
    x = self.norm1(x)
    x = self.act1(x)
    x = self.conv2(x)
    x = self.norm2(x)
    x += y
    x = self.act2(x)
    return x

# Version 2: never rebind x, build the residual branch in y
def forward(self, x):
    y = self.conv1(x)
    y = self.norm1(y)
    y = self.act1(y)
    y = self.conv2(y)
    y = self.norm2(y)
    y += x
    y = self.act2(y)
    return y

Are they identical? Which one is preferred? Why?
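One way to check the first question empirically is to put both forward variants into one module, run them on the same input with the same weights, and compare the results. This is a minimal sketch: the question does not say which layers are used, so the 3x3 Conv2d, BatchNorm2d, and ReLU layers with 8 channels are assumptions, and the names TwoWays, forward_a, and forward_b are hypothetical.

```python
import torch
import torch.nn as nn


class TwoWays(nn.Module):
    """One set of layers exposing both forward styles from the question.
    The concrete layer types and channel count are illustrative assumptions."""

    def __init__(self, channels=8):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(channels)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(channels)
        self.act2 = nn.ReLU()

    def forward_a(self, x):
        # Version 1: y is a second name for the input; x is rebound at each step
        y = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x += y  # in-place on norm2's output, not on the input tensor
        return self.act2(x)

    def forward_b(self, x):
        # Version 2: x is never rebound; the branch is built in y
        y = self.conv1(x)
        y = self.norm1(y)
        y = self.act1(y)
        y = self.conv2(y)
        y = self.norm2(y)
        y += x
        return self.act2(y)


torch.manual_seed(0)
block = TwoWays().eval()  # eval(): BatchNorm uses running stats, so calls have no side effects
x = torch.randn(2, 8, 16, 16)
x_before = x.clone()

with torch.no_grad():
    out_a = block.forward_a(x)
    out_b = block.forward_b(x)

print(torch.allclose(out_a, out_b))  # True: same computation
print(torch.equal(x, x_before))      # True: neither version modified the input
```

Both variants apply the same layers in the same order and add the untouched input at the end, so the difference between them is only which Python name holds which tensor.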


  • It doesn't matter, as long as you keep a reference to the original input.

    At a high level, you are trying to compute output = activation(input + f(input)), where f is the conv1, norm1, act1, conv2, norm2 chain.

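    Both versions compute exactly this. In the first one, y = x does not copy the tensor; it only binds a second name to the same object. That is enough, because x = self.conv1(x) then rebinds x to a new tensor instead of modifying the original. The in-place x += y (or y += x) is also safe, because it modifies the output of norm2, not the input. What you must avoid is losing the reference to the input, or modifying the input itself in place (for example, with an inplace=True activation applied directly to it) before the addition.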

    For what it's worth, I would separate out the residual connection and the sub-block just for clarity:

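    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self, ...):
            super().__init__()  # required before assigning submodules
            self.conv1 = ...
            self.norm1 = ...
            self.act = ...
            self.conv2 = ...
            self.norm2 = ...
        def forward(self, x):
            # residual branch f(x); x is rebound, never modified in place
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.act(x)
            x = self.conv2(x)
            x = self.norm2(x)
            return x

    class ResBlock(nn.Module):
        def __init__(self, block):
            super().__init__()
            self.block = block  # any module whose output shape matches its input
            self.act = ...
        def forward(self, x):
            # output = act(x + f(x))
            return self.act(x + self.block(x))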
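A runnable sketch of the Block/ResBlock split above follows. The concrete layers (3x3 Conv2d, BatchNorm2d, ReLU, 8 channels) and the optional act argument of ResBlock are assumptions standing in for the ... placeholders; any shape-preserving module can be wrapped. The second part illustrates the in-place pitfall: a branch that modifies its input in place (here nn.ReLU(inplace=True), chosen only for demonstration) also changes the tensor used on the skip path.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Residual branch f(x). Layer choices are illustrative assumptions."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm1 = nn.BatchNorm2d(channels)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.norm2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.norm2(self.conv2(self.act(self.norm1(self.conv1(x)))))


class ResBlock(nn.Module):
    """Wraps a shape-preserving module: output = act(x + block(x))."""

    def __init__(self, block, act=None):
        super().__init__()
        self.block = block
        self.act = act if act is not None else nn.ReLU()  # hypothetical default

    def forward(self, x):
        return self.act(x + self.block(x))


torch.manual_seed(0)
res = ResBlock(Block(8)).eval()
x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    out = res(x)
    expected = torch.relu(x + res.block(x))
print(out.shape)                      # torch.Size([2, 8, 16, 16])
print(torch.allclose(out, expected))  # True

# Pitfall: the branch mutates its input in place, so the skip path sees relu(x)
bad = ResBlock(nn.ReLU(inplace=True), act=nn.Identity())
x = torch.tensor([-1.0, 2.0])
out_bad = bad(x)
print(out_bad)  # tensor([0., 4.]) instead of x + relu(x) = tensor([-1., 4.])
```

Because ResBlock receives the branch as an argument, the skip connection is written once and can be reused with different sub-blocks, as long as they preserve the input shape and do not modify their input in place.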