I am making a novel type of fully connected layer that treats channels analogously to a 2D convolutional layer: each output channel is the sum of the outputs of separate fully connected layers (each followed by batch norm, LeakyReLU and dropout) applied to the input channels. I have a working prototype below, but it is slow because of the nested for loops in the forward method. In principle, all of these operations could run in parallel. How can I achieve this? Thank you.
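For reference, the layer described above can be written as a formula. The symbols are notation introduced here for illustration and are not from the original post. x_{b,i} is input channel i of sample b (a vector of length F_in). Each (output channel o, input channel i) pair has its own weight matrix W_{o,i}, bias beta_{o,i}, batch norm BN_{o,i} and dropout:

```latex
y_{b,o} = \sum_{i=1}^{C_{\mathrm{in}}} \mathrm{Dropout}\Big(\mathrm{LeakyReLU}\big(\mathrm{BN}_{o,i}(W_{o,i}\,x_{b,i} + \beta_{o,i})\big)\Big),
\qquad W_{o,i} \in \mathbb{R}^{F_{\mathrm{out}} \times F_{\mathrm{in}}},\quad o = 1, \dots, C_{\mathrm{out}}
```

None of the C_out × C_in terms depends on any other, which is why they can in principle all be computed in parallel.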
import torch
import torch.nn as nn

# assumes shape of input is (BATCH X channels X feature_dim)
class Channelwise_Linear(nn.Module):
    def __init__(self, in_channels, out_channels, in_feature_dim, out_feature_dim, drop_rate):
        super(Channelwise_Linear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_feature_dim = in_feature_dim
        self.out_feature_dim = out_feature_dim
        # nn.ModuleList (not a plain list) so the inner layers are registered as submodules:
        # their parameters then appear in .parameters() and follow .to(device), .train() and .eval()
        self.linear_layers = nn.ModuleList()
        self.bn_layers = nn.ModuleList()
        self.activation_layers = nn.ModuleList()
        self.do_layers = nn.ModuleList()
        for out_channel in range(out_channels):
            temp_linear = nn.ModuleList()
            temp_bn = nn.ModuleList()
            temp_AL = nn.ModuleList()
            temp_do = nn.ModuleList()
            for in_channel in range(in_channels):
                temp_linear.append(nn.Linear(in_feature_dim, out_feature_dim))
                temp_bn.append(nn.BatchNorm1d(out_feature_dim))
                temp_AL.append(nn.LeakyReLU())
                # note: on a 2-D (batch, features) input, Dropout1d zeroes whole rows (entire samples)
                temp_do.append(nn.Dropout1d(drop_rate))
            self.linear_layers.append(temp_linear)
            self.bn_layers.append(temp_bn)
            self.activation_layers.append(temp_AL)
            self.do_layers.append(temp_do)
    def forward(self, x):
        x = torch.transpose(x, 0, 1)  # now channel is first
        output_list = []
        for out_channel in range(self.out_channels):
            temp = []
            for in_channel in range(self.in_channels):
                intermediate = self.linear_layers[out_channel][in_channel](x[in_channel])
                intermediate = self.bn_layers[out_channel][in_channel](intermediate)
                intermediate = self.activation_layers[out_channel][in_channel](intermediate)
                intermediate = self.do_layers[out_channel][in_channel](intermediate)
                temp.append(intermediate)
            # each out_channel is the sum of the outputs of the per-input-channel nn.Linear layers, like Conv2d
            temp = torch.sum(torch.stack(temp), dim=0)
            output_list.append(temp)
        return torch.transpose(torch.stack(output_list), 0, 1)  # switch the batch back to first dim
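You're correct: the nested Python loops are what make your forward method slow. Each iteration runs its own small nn.Linear, BatchNorm1d, LeakyReLU and dropout call, so every forward pass issues out_channels × in_channels separate sets of small operations. Because all pairs have the same shapes, their weights can be stacked into one tensor and applied in a single batched operation. Here's a vectorized implementation: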
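import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# assumes shape of input is (batch, in_channels, in_feature_dim)
class Channelwise_Linear(nn.Module):
    def __init__(self, in_channels, out_channels, in_feature_dim, out_feature_dim, drop_rate):
        super(Channelwise_Linear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_feature_dim = in_feature_dim
        self.out_feature_dim = out_feature_dim
        # One (out_feature_dim x in_feature_dim) weight matrix and one bias vector per (out_channel, in_channel) pair.
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, out_feature_dim, in_feature_dim))
        self.bias = nn.Parameter(torch.empty(out_channels, in_channels, out_feature_dim))
        # One BatchNorm1d over the flattened (out_channel, in_channel, feature) axis is equivalent to
        # one BatchNorm1d(out_feature_dim) per pair: every feature is normalized over the batch only.
        self.batch_norm = nn.BatchNorm1d(out_channels * in_channels * out_feature_dim)
        self.dropout = nn.Dropout(drop_rate)
        self.reset_parameters()
    def reset_parameters(self):
        # View the weight as a 2-D (-1, in_feature_dim) matrix so that fan_in = in_feature_dim and every
        # pair gets the same init as a standalone nn.Linear(in_feature_dim, out_feature_dim).
        # (On the raw 4-D tensor, fan_in would be in_channels * out_feature_dim * in_feature_dim.)
        nn.init.kaiming_uniform_(self.weight.view(-1, self.in_feature_dim), a=math.sqrt(5))
        bound = 1 / math.sqrt(self.in_feature_dim)
        nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, x):
        batch = x.size(0)
        # Apply every pair's linear map at once:
        # out[b, o, i, :] = weight[o, i] @ x[b, i, :] + bias[o, i]  -> (batch, out_channels, in_channels, out_feature_dim)
        out = torch.einsum('bif,oigf->boig', x, self.weight) + self.bias
        # Flatten all but the batch dim so BatchNorm1d normalizes each feature of each pair independently
        out = self.batch_norm(out.reshape(batch, -1))
        out = F.leaky_relu(out)  # default negative_slope=0.01, same as nn.LeakyReLU()
        out = self.dropout(out)
        # Restore the pair structure and sum over input channels (like Conv2d) -> (batch, out_channels, out_feature_dim)
        out = out.view(batch, self.out_channels, self.in_channels, self.out_feature_dim)
        return out.sum(dim=2)
Rather than creating out_channels × in_channels separate nn.Linear modules, I store the weights and biases of all (output channel, input channel) pairs in two stacked parameters. weight has shape (out_channels, in_channels, out_feature_dim, in_feature_dim) and bias has shape (out_channels, in_channels, out_feature_dim). A single torch.einsum applies every pair's linear map to its input channel at once and produces a (batch, out_channels, in_channels, out_feature_dim) tensor. Flattening that tensor to (batch, out_channels * in_channels * out_feature_dim) lets one BatchNorm1d normalize each feature over the batch, which is exactly what the per-pair BatchNorm1d(out_feature_dim) layers did. LeakyReLU and dropout are elementwise, so they act on the flat tensor directly. Only then is the result reshaped back and summed over the input-channel dimension. The order matters: each pair has its own batch-norm statistics and LeakyReLU is nonlinear, so these steps must come before the sum, not after it.
One deliberate difference is dropout: this version uses elementwise nn.Dropout. The prototype's nn.Dropout1d, given a 2-D (batch, features) tensor, treats it as a single unbatched (channels, length) input and zeroes entire rows, i.e. whole samples. With dropout disabled (eval mode) and the same parameters and batch-norm statistics, the two layers compute the same function.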
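The equivalence claimed above can be checked numerically. The sketch below builds per-pair nn.Linear and BatchNorm1d layers the way the loop prototype does. It stacks their weights into the vectorized layout and compares the two paths. The sizes are arbitrary illustrative values, and dropout is left out because it is random.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, c_in, c_out, f_in, f_out = 4, 3, 2, 5, 6  # small illustrative sizes

# Per-pair layers as in the loop prototype, indexed [out_channel][in_channel]
linears = [[nn.Linear(f_in, f_out) for _ in range(c_in)] for _ in range(c_out)]
bns = [[nn.BatchNorm1d(f_out) for _ in range(c_in)] for _ in range(c_out)]

# Stack the per-pair parameters into the vectorized layout
weight = torch.stack([torch.stack([lin.weight for lin in row]) for row in linears])  # (c_out, c_in, f_out, f_in)
bias = torch.stack([torch.stack([lin.bias for lin in row]) for row in linears])      # (c_out, c_in, f_out)
flat_bn = nn.BatchNorm1d(c_out * c_in * f_out)  # one BN over every pair's features

x = torch.randn(batch, c_in, f_in)

with torch.no_grad():
    # Loop: out[:, o] = sum_i LeakyReLU(BN_oi(Linear_oi(x[:, i])))
    loop_out = torch.stack(
        [sum(F.leaky_relu(bns[o][i](linears[o][i](x[:, i]))) for i in range(c_in))
         for o in range(c_out)],
        dim=1,
    )
    # Vectorized: one einsum for all pairs, flat BN + activation, then sum over input channels
    pairs = torch.einsum("bif,oigf->boig", x, weight) + bias
    flat = F.leaky_relu(flat_bn(pairs.reshape(batch, -1)))
    vec_out = flat.view(batch, c_out, c_in, f_out).sum(dim=2)

print(vec_out.shape)  # torch.Size([4, 2, 6])
print(torch.allclose(loop_out, vec_out, atol=1e-5))
```

Both BN paths are in training mode here, so each normalizes with the current batch's statistics. A True result means the flatten-then-BatchNorm1d trick matches the per-pair layers.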
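This should be considerably faster than the looped version. The loop makes out_channels × in_channels small linear, batch-norm, activation and dropout calls, each with its own Python overhead and, on a GPU, its own kernel launches. The vectorized forward pass replaces them with a handful of large tensor operations. How much faster it is depends on your sizes and hardware. The cost is peak memory: the whole (batch, out_channels, in_channels, out_feature_dim) intermediate exists at once. The loop, by contrast, only needs one output channel's intermediates at a time when no autograd graph is being kept. This can matter when the channel counts are large.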
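To put the memory trade-off in numbers, here is a small calculation for one intermediate tensor. The helper intermediate_bytes and the sizes are hypothetical and were chosen only for illustration.

```python
import torch


def intermediate_bytes(*shape, dtype=torch.float32):
    """Bytes for one dense tensor of the given shape and dtype (hypothetical helper)."""
    numel = 1
    for s in shape:
        numel *= s
    return numel * torch.empty(0, dtype=dtype).element_size()


batch, c_out, c_in, f_out = 64, 16, 8, 128  # illustrative sizes

# Vectorized layer: all (out_channel, in_channel) pairs exist at once
vectorized = intermediate_bytes(batch, c_out, c_in, f_out)
# Loop prototype without autograd: one output channel's stacked pairs at a time
per_out_channel = intermediate_bytes(c_in, batch, f_out)

print(vectorized, per_out_channel)  # 4194304 262144
```

The ratio between the two is exactly out_channels. The figures are per intermediate tensor, and batch norm, activation and dropout each produce another tensor of the same size.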
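Note that in the provided code, the reset_parameters() function initializes the weights with Kaiming uniform initialization and the biases with a uniform distribution bounded by 1/sqrt(fan_in). With a=math.sqrt(5) this is the same scheme nn.Linear uses by default, and the Kaiming bound then also works out to 1/sqrt(fan_in). Kaiming-style scaling is designed for ReLU-family activations such as the LeakyReLU used here. Scaling by the fan-in keeps the variance of activations roughly constant from layer to layer, which reduces the risk of vanishing or exploding gradients. The textbook He setting for LeakyReLU would instead pass the activation's negative slope as a.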
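The fan-in that the init functions see depends on the weight tensor's shape, which matters for a stacked 4-D weight. The sketch below uses nn.init._calculate_fan_in_and_fan_out, a private (underscore) PyTorch helper that the answer's code also calls. It compares the fan-in of the raw 4-D tensor with that of a 2-D view whose last dimension is in_feature_dim. The sizes are illustrative.

```python
import math

import torch
import torch.nn as nn

c_out, c_in, f_out, f_in = 4, 3, 16, 32  # illustrative sizes
weight = torch.empty(c_out, c_in, f_out, f_in)

# On a 4-D tensor, PyTorch treats dim 1 like conv input channels and dims 2+ as the
# "receptive field", so fan_in = c_in * f_out * f_in.
fan_in_4d, _ = nn.init._calculate_fan_in_and_fan_out(weight)
# On a 2-D view with f_in columns, fan_in = f_in, matching nn.Linear(f_in, f_out).
fan_in_2d, _ = nn.init._calculate_fan_in_and_fan_out(weight.view(-1, f_in))
print(fan_in_4d, fan_in_2d)  # 1536 32

# Initializing through the view fills the underlying 4-D tensor in place.
nn.init.kaiming_uniform_(weight.view(-1, f_in), a=math.sqrt(5))
# With a = sqrt(5) the Kaiming bound reduces to 1 / sqrt(fan_in) (small float tolerance added).
print(weight.abs().max().item() <= 1 / math.sqrt(f_in) + 1e-6)
```

Calling kaiming_uniform_ on the raw 4-D weight would therefore give every pair much smaller weights than a standalone nn.Linear would get.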
It's good practice to keep initialization in a separate reset_parameters() method. You can change the initialization scheme later without touching the constructor or forward(), and you can reinitialize the layer (for example, to restart training from scratch) by calling it again. Built-in layers such as nn.Linear and nn.BatchNorm1d follow the same convention.
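One common way to use that convention is to walk every submodule and call reset_parameters() wherever it is defined. This is a minimal sketch: the helper name reinitialize is hypothetical, and a small nn.Sequential stands in for a real model. A custom layer that defines reset_parameters(), such as Channelwise_Linear, would be picked up by the same loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.LeakyReLU())


def reinitialize(module: nn.Module) -> None:
    """Call reset_parameters() on every submodule that defines it (hypothetical helper)."""
    for m in module.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()


before = model[0].weight.detach().clone()
reinitialize(model)
after = model[0].weight.detach().clone()
print(torch.equal(before, after))  # False: the Linear weights were redrawn
```

Modules without the method, such as nn.LeakyReLU and the nn.Sequential container, are skipped. nn.BatchNorm1d's reset_parameters() also clears its running statistics.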