python, deep-learning, pytorch

How to vectorize the forward method of this pytorch block?


I am making a novel type of fully connected layer that operates with channels analogously to a 2D convolutional layer. Namely, each output channel is the sum of fully connected layers applied to the input channels. I have a fully working prototype below. But it is slow because of the for loops in the forward method. Theoretically all of these operations can happen in parallel. How can I achieve this? Thank you.

import torch
import torch.nn as nn

# assumes shape of input is (BATCH x channels x feature_dim)
class Channelwise_Linear(nn.Module):
  def __init__(self, in_channels, out_channels, in_feature_dim, out_feature_dim, drop_rate):
      super(Channelwise_Linear, self).__init__()
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.in_feature_dim = in_feature_dim
      self.out_feature_dim = out_feature_dim

      # use nn.ModuleList (not plain Python lists) so the nested modules are registered
      self.linear_layers = nn.ModuleList()
      self.bn_layers = nn.ModuleList()
      self.activation_layers = nn.ModuleList()
      self.do_layers = nn.ModuleList()
      for out_channel in range(out_channels):
          temp_linear = nn.ModuleList()
          temp_bn = nn.ModuleList()
          temp_AL = nn.ModuleList()
          temp_do = nn.ModuleList()
          for in_channel in range(in_channels):
            temp_linear.append(nn.Linear(in_feature_dim, out_feature_dim))
            temp_bn.append(nn.BatchNorm1d(out_feature_dim))
            temp_AL.append(nn.LeakyReLU())
            temp_do.append(nn.Dropout1d(drop_rate))
          self.linear_layers.append(temp_linear)
          self.bn_layers.append(temp_bn)
          self.activation_layers.append(temp_AL)
          self.do_layers.append(temp_do)

  def forward(self, x):

      x = torch.transpose(x,0,1) # now channel is first

      output_list = []
      for out_channel in range(self.out_channels):
          temp = []
          for in_channel in range(self.in_channels):
            intermediate = self.linear_layers[out_channel][in_channel](x[in_channel])
            intermediate = self.bn_layers[out_channel][in_channel](intermediate)
            intermediate = self.activation_layers[out_channel][in_channel](intermediate)
            intermediate = self.do_layers[out_channel][in_channel](intermediate)
            temp.append(intermediate)
          temp = torch.sum(torch.stack(temp), dim=0) # each out_channel is the sum of the per-in_channel outputs, like Conv2d

          output_list.append(temp)

      return torch.transpose(torch.stack(output_list), 0, 1) # switch the batch back to first dim

Solution

  • You're correct: the nested for loops are what make your forward method slow. Here's a more efficient implementation that vectorizes the computation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    
    class Channelwise_Linear(nn.Module):
        def __init__(self, in_channels, out_channels, in_feature_dim, out_feature_dim, drop_rate):
            super(Channelwise_Linear, self).__init__()
    
            self.in_channels = in_channels
            self.out_channels = out_channels
    
            # Weights and biases for linear layers for all in-channels and out-channels.
            self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, out_feature_dim, in_feature_dim))
            self.bias = nn.Parameter(torch.Tensor(out_channels, in_channels, out_feature_dim))
    
            self.batch_norm = nn.BatchNorm1d(out_channels)
            self.dropout = nn.Dropout(drop_rate)
            self.reset_parameters()
    
        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
    
        def forward(self, x):
            # x: (batch, in_channels, in_feature_dim)
            # Apply every (out_channel, in_channel) linear map at once via broadcasting:
            # weight is (out_c, in_c, out_f, in_f); x is reshaped to (batch, 1, in_c, in_f, 1),
            # so matmul yields (batch, out_c, in_c, out_f, 1).
            out = torch.matmul(self.weight, x.unsqueeze(1).unsqueeze(-1)).squeeze(-1) + self.bias
    
            # Sum over the input-channel dimension, as in the original layer
            out = out.sum(dim=2)  # (batch, out_channels, out_feature_dim)
    
            # A single BatchNorm1d(out_channels) normalizes across the out_channels dimension
            out = self.batch_norm(out)
            out = F.leaky_relu(out)
            out = self.dropout(out)
    
            return out
    
    

    I create the weight and bias for all pairs of output and input channels together, rather than as separate modules. Reshaping the input so it broadcasts against the weight lets a single torch.matmul apply every per-pair linear map at once, giving a (batch, out_channels, in_channels, out_feature_dim) tensor. Summing along the input-channel dimension then yields the (batch, out_channels, out_feature_dim) output, which is passed through batch normalization, activation, and dropout. Note that this differs slightly from your original layer: normalization, activation, and dropout are applied once to the summed result (with a single BatchNorm1d(out_channels)) rather than to every in-channel branch before the sum.
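
    For readability, the same computation (the per-pair linear maps plus the sum over input channels) can also be written as one torch.einsum call. This is a minimal sketch of the linear part only, using the same weight and bias shapes as above; the function name is just for illustration:

    import torch

    def channelwise_linear_einsum(x, weight, bias):
        # x: (batch, in_c, in_f); weight: (out_c, in_c, out_f, in_f); bias: (out_c, in_c, out_f)
        # 'oifk,bik->bof' applies every per-pair linear map and sums over the in_channels index i
        out = torch.einsum('oifk,bik->bof', weight, x)
        # the per-pair biases likewise collapse to a single (out_c, out_f) bias
        return out + bias.sum(dim=1)

    Batch normalization, activation, and dropout would then be applied to the result exactly as in the forward method above.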

    This implementation should be significantly faster than the original with for loops. However, it materializes the full (batch, out_channels, in_channels, out_feature_dim) intermediate tensor, so it uses more memory, which might be a problem if your model is very large.
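
    As a quick sanity check (the sizes below are purely hypothetical), you can confirm that the vectorized layer produces the expected output shape:

    import torch

    batch, in_c, out_c, in_f, out_f = 32, 8, 16, 64, 128
    layer = Channelwise_Linear(in_c, out_c, in_f, out_f, drop_rate=0.1)

    x = torch.randn(batch, in_c, in_f)
    y = layer(x)
    print(y.shape)  # torch.Size([32, 16, 128]), i.e. (batch, out_channels, out_feature_dim)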

    Note that in the provided code, the reset_parameters() method initializes the weights with Kaiming uniform initialization and the biases from a uniform distribution. This kind of initialization is commonly used, especially with ReLU (or variants like the LeakyReLU used here) as the activation function, because it helps keep the variance of activations roughly constant across layers, reducing the risk of vanishing or exploding gradients.
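
    For reference, for a 4-D weight like this one, PyTorch's fan-in calculation uses dimension 1 (here in_channels) as the number of input feature maps and the product of the remaining dimensions as the receptive field size, so with the hypothetical sizes from the example above the bias bound works out as follows:

    import math
    import torch
    import torch.nn as nn

    weight = torch.empty(16, 8, 128, 64)  # (out_channels, in_channels, out_feature_dim, in_feature_dim)
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
    print(fan_in)                 # 8 * 128 * 64 = 65536
    print(1 / math.sqrt(fan_in))  # bias bound = 1/256 ≈ 0.0039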

    It's good practice to define a separate method like reset_parameters() for this purpose, because it provides flexibility to change the initialization method later without modifying the constructor or forward methods. Also, if you want to reinitialize the model at some point (like if you want to restart training from scratch), you can easily do so by calling this method.
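
    For example, to restart training from scratch you could reinitialize the layer in place and rebuild the optimizer so its state is fresh as well (the optimizer choice here is just illustrative):

    layer.reset_parameters()
    optimizer = torch.optim.Adam(layer.parameters(), lr=1e-3)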