
Does inplace matter when we return ReLU(x)?


Is there a difference between the following two classes? I know what inplace does: with inplace=True, calling function(x) modifies x directly, so you don't need to write x = function(x). But here, because we return self.conv(x), it should not matter, right?

import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        # Downsampling uses a regular convolution, upsampling a transposed one.
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        self.conv = nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels),
            nn.ReLU() if use_act else nn.Identity(),  # out-of-place ReLU (the default)
        )

    def forward(self, x):
        return self.conv(x)


class ConvBlockInplace(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        # Downsampling uses a regular convolution, upsampling a transposed one.
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        self.conv = nn.Sequential(
            conv,
            nn.InstanceNorm2d(out_channels),
            # The only difference from ConvBlock: ReLU overwrites its input tensor.
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        return self.conv(x)

Solution

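  • In-place operations perform exactly the same amount of computation as their out-of-place counterparts, so given the same weights both classes return the same values. The difference is memory: the in-place ReLU overwrites its input (here, the output of InstanceNorm2d) instead of allocating a new tensor, so there are fewer memory accesses. That matters only if your task is memory-bound.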


    I used the ptflops FLOPs counter to generate the following statistics:

    ConvBlock(
      0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
      (conv): Sequential(
        0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
        (0): Conv2d(0.0 M, 100.000% Params, 0.014 GMac, 93.333% MACs, 3, 10, kernel_size=(3, 3), stride=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, 10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, )
      )
    )
    Computational complexity:       0.01 GMac
    Number of parameters:           280     
    Warning: module ConvBlockInplace is treated as a zero-op.
    ConvBlockInplace(
      0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
      (conv): Sequential(
        0.0 M, 100.000% Params, 0.015 GMac, 100.000% MACs, 
        (0): Conv2d(0.0 M, 100.000% Params, 0.014 GMac, 93.333% MACs, 3, 10, kernel_size=(3, 3), stride=(1, 1), padding_mode=reflect)
        (1): InstanceNorm2d(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, 10, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 3.333% MACs, inplace=True)
      )
    )
    Computational complexity:       0.01 GMac
    Number of parameters:           280
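
The claim above can also be checked directly. The following is a minimal sketch, not the exact setup used for the statistics: `make_block` is a hypothetical stand-in for the two classes in the question, and the input shape, `kernel_size=3`, and `padding=1` are assumptions chosen for illustration. It compares the outputs of both variants, confirms that the caller's `x` is left untouched (the ReLU overwrites the intermediate InstanceNorm2d output, not `x`), and shows that an in-place ReLU reuses its input buffer while the default one allocates a new tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_block(inplace):
    # Hypothetical stand-in for ConvBlock / ConvBlockInplace: identical layers,
    # differing only in the inplace flag of the final ReLU.
    return nn.Sequential(
        nn.Conv2d(3, 10, kernel_size=3, padding=1, padding_mode="reflect"),
        nn.InstanceNorm2d(10),
        nn.ReLU(inplace=inplace),
    )


block = make_block(inplace=False)
block_inplace = make_block(inplace=True)
# Copy the weights so both blocks compute the same function.
block_inplace.load_state_dict(block.state_dict())

x = torch.randn(2, 3, 16, 16)  # assumed input shape, for illustration only
x_before = x.clone()

with torch.no_grad():
    out = block(x)
    out_inplace = block_inplace(x)

# Both variants return the same values, and the caller's input is unchanged.
same_output = torch.allclose(out, out_inplace)
input_untouched = torch.equal(x, x_before)

# The difference is at the ReLU itself: in-place reuses the input's memory,
# out-of-place writes its result into a newly allocated tensor.
t = torch.randn(4)
reuses_buffer = nn.ReLU(inplace=True)(t).data_ptr() == t.data_ptr()
t2 = torch.randn(4)
allocates_new = nn.ReLU()(t2).data_ptr() != t2.data_ptr()

print(same_output, input_untouched, reuses_buffer, allocates_new)
```

If the activation is applied directly to a tensor you still need afterwards (rather than to an intermediate result, as in these blocks), `inplace=True` would overwrite that tensor, so the flag is only interchangeable in cases like the one in the question.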