
Is there a way to use list of indices to simultaneously access the modules of nn.ModuleList in python?


I am working with a PyTorch nn.ModuleList as described below:

decision_modules = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])

Our input data has the shape x = torch.rand(32, 768), where 32 is the batch size and 768 is the feature dimension.

Now, for each data point in a minibatch of 32, we want to select 4 decision modules from decision_modules. The 4 modules are selected using an index tensor, as explained below.

I have an index tensor ind created as torch.randint(0, 10, (32, 4)), i.e. of shape (32, 4).

I want a solution that avoids loops, as looping slows down execution significantly.

But the following code throws an error:

import torch
import torch.nn as nn

linears = nn.ModuleList([nn.Linear(768, 768) for i in range(10)])
ind = torch.randint(0, 10, (32, 4))
x = torch.rand(32, 768)

out = linears[ind](x)

The following error was observed:

File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:334, in ModuleList.__getitem__(self, idx)
    332     return self.__class__(list(self._modules.values())[idx])
    333 else:
--> 334     return self._modules[self._get_abs_string_index(idx)]

File ~\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\container.py:314, in ModuleList._get_abs_string_index(self, idx)
    312 def _get_abs_string_index(self, idx):
    313     """Get the absolute index for the list of modules."""
--> 314     idx = operator.index(idx)
    315     if not (-len(self) <= idx < len(self)):
    316         raise IndexError(f"index {idx} is out of range")

TypeError: only integer tensors of a single element can be converted to an index

The expected output shape is (32,4,768).

Any help would be highly appreciated.
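For reference, a straightforward loop-based version does produce the desired (32, 4, 768) output, but it is exactly the kind of Python-level looping I want to avoid:

```python
import torch
import torch.nn as nn

linears = nn.ModuleList([nn.Linear(768, 768) for _ in range(10)])
ind = torch.randint(0, 10, (32, 4))
x = torch.rand(32, 768)

# Apply the chosen module to each (batch item, choice) pair one at a time.
out = torch.stack([
    torch.stack([linears[int(ind[b, c])](x[b]) for c in range(ind.shape[1])])
    for b in range(x.shape[0])
])
print(out.shape)  # torch.Size([32, 4, 768])
```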


Solution

  • The ind tensor is of size (bs, n_choice), which means we're choosing a different set of experts for each item in the batch.

    With this setup, the most efficient approach is to compute all experts for all batch items, then gather the desired choices afterwards. This is more performant on the GPU than looping over the individual experts.

    Since each expert is a linear layer, you can compute all of them at once with a single nn.Linear whose output size is n_experts * d_in.

    import torch
    import torch.nn as nn
    
    d_in = 768
    n_experts = 10
    bs = 32
    n_choice = 4
    
    # Create a single large linear layer
    fused_linear = nn.Linear(d_in, d_in * n_experts)
    
    indices = torch.randint(0, n_experts, (bs, n_choice))
    x = torch.randn(bs, d_in)
    
    # Forward pass through the fused layer
    y = fused_linear(x)  # Shape: [bs, d_in * n_experts]
    
    # Reshape to separate the experts dimension
    ys = y.reshape(bs, n_experts, d_in)  # Shape: [bs, n_experts, d_in]
    
    # Gather the chosen experts
    ys = torch.gather(ys, 1, indices.unsqueeze(-1).expand(-1, -1, d_in))
    

    The output ys will be of shape (bs, n_choice, d_in).
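If you already have trained weights in an nn.ModuleList, you can pack them into the fused layer and verify the gather against a direct per-expert computation. This is a minimal sketch (small sizes for a quick check); it assumes expert e occupies output rows [e*d_in, (e+1)*d_in) of the fused weight, matching the reshape above:

```python
import torch
import torch.nn as nn

d_in, n_experts, bs, n_choice = 8, 3, 2, 2  # small sizes for the check

experts = nn.ModuleList([nn.Linear(d_in, d_in) for _ in range(n_experts)])

# Pack per-expert weights into one fused layer: expert e occupies
# output rows [e*d_in, (e+1)*d_in).
fused = nn.Linear(d_in, d_in * n_experts)
with torch.no_grad():
    fused.weight.copy_(torch.cat([m.weight for m in experts], dim=0))
    fused.bias.copy_(torch.cat([m.bias for m in experts], dim=0))

x = torch.randn(bs, d_in)
ind = torch.randint(0, n_experts, (bs, n_choice))

ys = fused(x).reshape(bs, n_experts, d_in)
ys = torch.gather(ys, 1, ind.unsqueeze(-1).expand(-1, -1, d_in))

# Reference: apply each chosen expert directly with a loop.
ref = torch.stack([
    torch.stack([experts[int(ind[b, c])](x[b]) for c in range(n_choice)])
    for b in range(bs)
])
print(torch.allclose(ys, ref, atol=1e-5))  # True
```

The equivalence check confirms that the fused layer plus gather computes exactly the same result as indexing the ModuleList per item, just in one batched matmul.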