Is there a way to use a list of indices to simultaneously access the modules of an nn.ModuleList in Python?
I am working with a PyTorch nn.ModuleList as described below:
decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
Our input data is x = torch.rand(32, 768), where 32 is the batch size and 768 is the feature dimension.
Now, for each data point in the minibatch of 32, we want to select 4 decision modules from decision_modules. These 4 modules are selected using an index matrix, as explained below.
I have an index matrix ind = torch.randint(0, 10, (32, 4)), i.e. of shape (32, 4): one row of 4 module indices per batch element.
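For reference, a loop-based version of what I am after would look roughly like the sketch below (illustrative only, using decision_modules, x and ind from above); this is exactly the kind of loop I want to avoid:

# Apply module ind[b, j] to x[b] for every batch element b and choice j.
out = torch.stack(
    [torch.stack([decision_modules[int(ind[b, j])](x[b]) for j in range(4)])
     for b in range(32)]
)  # expected shape: (32, 4, 768)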
I want a solution without loops, since looping slows down execution significantly. But the following code throws an error:
import torch
import torch.nn as nn
linears = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
ind=torch.randint(0,10,(32,4))
input=torch.rand(32,768)
out=linears[ind](input)
The following error was observed
File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:334, in ModuleList.__getitem__(self, idx)
    332     return self.__class__(list(self._modules.values())[idx])
    333 else:
--> 334     return self._modules[self._get_abs_string_index(idx)]

File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:314, in ModuleList._get_abs_string_index(self, idx)
    312 def _get_abs_string_index(self, idx):
    313     """Get the absolute index for the list of modules."""
--> 314     idx = operator.index(idx)
    315     if not (-len(self) <= idx < len(self)):
    316         raise IndexError(f"index {idx} is out of range")
TypeError: only integer tensors of a single element can be converted to an index
The expected output shape is (32, 4, 768). Any help would be highly appreciated.
The ind tensor is of size (bs, n_decisions), which means we're choosing a different set of experts for each item in the batch.
With this setup, the most efficient way to compute the output is to compute all experts for all batch items, then gather the desired choices afterwards. This is more performant on the GPU than looping over the individual experts.
Since each expert is a linear layer, you can compute all the experts at once using a single linear layer with output size n_experts * d_in.
import torch
import torch.nn as nn

d_in = 768
n_experts = 10
bs = 32
n_choice = 4
# Create a single large linear layer
fused_linear = nn.Linear(d_in, d_in * n_experts)
indices = torch.randint(0, n_experts, (bs, n_choice))
x = torch.randn(bs, d_in)
# Forward pass through the fused layer
y = fused_linear(x) # Shape: [bs, d_in * n_experts]
# Reshape to separate the experts dimension
ys = y.reshape(bs, n_experts, d_in) # Shape: [bs, n_experts, d_in]
# Gather the chosen experts
ys = torch.gather(ys, 1, indices.unsqueeze(-1).expand(-1, -1, d_in))
The output ys will be of shape (bs, n_choice, d_in), where ys[b, c] is the output of expert indices[b, c] applied to x[b].
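If you want the fused layer to reproduce the weights of an existing ModuleList (as in the question), here is a minimal sketch, assuming the names defined above, of loading those weights and sanity-checking the result against an explicit loop:

# Copy each expert's weights into the matching slice of the fused layer.
linears = nn.ModuleList([nn.Linear(d_in, d_in) for _ in range(n_experts)])
with torch.no_grad():
    for e, lin in enumerate(linears):
        fused_linear.weight[e * d_in:(e + 1) * d_in].copy_(lin.weight)
        fused_linear.bias[e * d_in:(e + 1) * d_in].copy_(lin.bias)

# Recompute the fused output with the copied weights.
ys_fused = fused_linear(x).reshape(bs, n_experts, d_in)
ys_fused = torch.gather(ys_fused, 1, indices.unsqueeze(-1).expand(-1, -1, d_in))

# Loop-based reference: apply the chosen expert to each (batch item, choice) pair.
ref = torch.stack(
    [torch.stack([linears[int(e)](x[b]) for e in indices[b]]) for b in range(bs)]
)
print(torch.allclose(ref, ys_fused, atol=1e-5))  # True, up to float rounding

Alternatively, if you'd rather keep the per-expert weights inside the original ModuleList, another option (a sketch, not necessarily faster) is to stack the weights, index the chosen experts, and contract with einsum; this computes only the selected experts, at the cost of gathering large weight tensors:

W = torch.stack([lin.weight for lin in linears])  # (n_experts, d_in, d_in)
b = torch.stack([lin.bias for lin in linears])    # (n_experts, d_in)
W_sel = W[indices]                                # (bs, n_choice, d_in, d_in)
b_sel = b[indices]                                # (bs, n_choice, d_in)
out = torch.einsum('bcoi,bi->bco', W_sel, x) + b_sel  # (bs, n_choice, d_in)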