
Parallelizing a natively single-batch PyTorch model


Is it possible to parallelize a (natively) single-batch model across a batch of inputs?

Usually parallelization is done with torch.bmm (batched matrix multiplication) instead of torch.matmul, reserving one dimension specifically for the batch. However, no such batch dimension is available for, e.g., torch.tensordot.
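To make the contrast concrete, here is a small sketch (the shapes are arbitrary choices for illustration) of how torch.bmm pairs up a shared batch dimension, how torch.matmul broadcasts over it as well, and how torch.tensordot instead takes an outer product over every dimension it does not contract:

```python
import torch

batch_size, n, k, m = 4, 3, 2, 5
a = torch.randn(batch_size, n, k)
b = torch.randn(batch_size, k, m)

# torch.bmm treats dim 0 as the batch: one (n, k) @ (k, m) product per batch entry
out_bmm = torch.bmm(a, b)
print(out_bmm.shape)  # torch.Size([4, 3, 5])

# torch.matmul broadcasts over leading dims, giving the same result here
out_matmul = torch.matmul(a, b)
assert torch.allclose(out_bmm, out_matmul, atol=1e-6)

# torch.tensordot has no notion of a batch: it contracts the listed dims and
# takes an outer product over the rest, so the two batch dims are NOT paired up
out_td = torch.tensordot(a, b, dims=([2], [1]))
print(out_td.shape)  # torch.Size([4, 3, 4, 5])

# The batched result is only the "diagonal" over the two batch axes
diag = torch.stack([out_td[i, :, i, :] for i in range(batch_size)])
assert torch.allclose(diag, out_bmm, atol=1e-6)
```

Extracting that diagonal wastes most of the computation, which is why tensordot alone is not a good fit for batching.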

So, given such a model, is it possible to compute the gradient for each sample in the batch in parallel? Ideally, the parallelization should work for both training and inference.

A code example:

import torch
import torch.nn as nn

class LinearMultidimModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LinearMultidimModel, self).__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, hidden_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # Using torch.tensordot to perform the linear transformation
        out = torch.tensordot(x, self.weight, dims=[[0,1],[0,1]]) + self.bias
        return out

# Example usage
input_dim = 3
hidden_dim = 2
output_dim = 1
model = LinearMultidimModel(input_dim, hidden_dim, output_dim)

# Dummy input
x = torch.randn(input_dim, hidden_dim)  # But what if I want to put in a batch, torch.randn(batch_size, input_dim, hidden_dim)?
output = model(x)
print(output)
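As the comment in the example hints, the model cannot simply be handed a batch. A minimal sketch of why (the constructor here takes hidden_dim explicitly, and the batch size of 8 is arbitrary): with a leading batch dimension, dims 0 and 1 of x become (batch_size, input_dim), which tensordot tries to contract against the weight's (input_dim, hidden_dim) and rejects because the sizes do not match.

```python
import torch
import torch.nn as nn

class LinearMultidimModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, hidden_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # Contracts dims 0 and 1 of x against dims 0 and 1 of the weight
        return torch.tensordot(x, self.weight, dims=[[0, 1], [0, 1]]) + self.bias

model = LinearMultidimModel(3, 2, 1)
print(model(torch.randn(3, 2)).shape)  # torch.Size([1])

# A batched input shifts every dim by one, so the contracted sizes no longer line up
batched_failed = False
try:
    model(torch.randn(8, 3, 2))
except RuntimeError as err:
    batched_failed = True
    print("batched call failed:", err)
```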

Keep in mind that without hidden_dim the model already parallelizes natively: if hidden_dim is removed entirely, the model returns a batched result for an input such as

x = torch.randn(5, input_dim).
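For comparison, a sketch of the hidden_dim-free case. The weight shape (input_dim, output_dim) and the integer dims=1 (contract the last dimension of x with the first dimension of the weight) are assumptions about how such a model would be written; the point is that the leading batch dimension of x is carried through untouched:

```python
import torch

input_dim, output_dim, batch_size = 3, 1, 5
weight = torch.randn(input_dim, output_dim)
bias = torch.randn(output_dim)

# dims=1 contracts x's last dim with weight's first dim;
# every other dim of x (here the batch dim) is kept
x = torch.randn(batch_size, input_dim)
out = torch.tensordot(x, weight, dims=1) + bias
print(out.shape)  # torch.Size([5, 1])

# Same result computed sample by sample
rows = torch.stack([torch.tensordot(xi, weight, dims=1) + bias for xi in x])
assert torch.allclose(out, rows, atol=1e-6)
```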

I've tried using torch.einsum, but that only works for a fixed number of hidden dimensions...
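For reference, the fixed-hidden-dimension einsum version might look like the sketch below (the subscript string is an assumption and hard-codes exactly one hidden dimension):

```python
import torch

batch_size, input_dim, hidden_dim, output_dim = 8, 3, 2, 1
weight = torch.randn(input_dim, hidden_dim, output_dim)
bias = torch.randn(output_dim)
x = torch.randn(batch_size, input_dim, hidden_dim)

# b = batch, i = input_dim, h = hidden_dim, o = output_dim;
# i and h are summed over, b is kept as the batch dim
out = torch.einsum("bih,iho->bo", x, weight) + bias
print(out.shape)  # torch.Size([8, 1])

# Matches the unbatched tensordot applied sample by sample
ref = torch.stack([torch.tensordot(xi, weight, dims=[[0, 1], [0, 1]]) + bias for xi in x])
assert torch.allclose(out, ref, atol=1e-6)
```

Each additional hidden dimension needs another letter in both operands of the subscript string, which is the limitation described in the question.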


Solution

  • You can use torch.vmap (also available as torch.func.vmap) for exactly this purpose:

    import torch
    import torch.nn as nn

    class LinearMultidimModel(nn.Module):
        def __init__(self, input_dim, hidden_dim, output_dim):
            super(LinearMultidimModel, self).__init__()
            self.weight = nn.Parameter(torch.randn(input_dim, hidden_dim, output_dim))
            self.bias = nn.Parameter(torch.randn(output_dim))
    
        def forward(self, x):
            # Using torch.tensordot to perform the linear transformation
            out = torch.tensordot(x, self.weight, dims=[[0,1],[0,1]]) + self.bias
            return out
        
    input_dim = 3
    hidden_dim = 2
    output_dim = 1
    model = LinearMultidimModel(input_dim, hidden_dim, output_dim)
    
    # create input with batch dimension
    batch_size = 8
    x = torch.randn(batch_size, input_dim, hidden_dim)
    
    # reference result: run the unbatched model on each sample in a Python loop
    y1 = torch.stack([model(i) for i in x])
    
    # vmap maps the model over dim 0 of x (the batch dimension)
    model_batched = torch.func.vmap(model)
    
    # batch inference
    y2 = model_batched(x)
    
    # assert outputs are the same
    assert torch.allclose(y1, y2)
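The answer above covers inference. For the gradient part of the question, torch.func can also produce per-sample gradients by composing grad with vmap and calling the module through functional_call, the pattern used in PyTorch's per-sample-gradients tutorial. A minimal sketch, assuming a mean-squared-error loss and random targets purely for illustration. It also shows that the vmapped model trains with ordinary autograd, and checks that the usual .grad equals the average of the per-sample gradients:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap

class LinearMultidimModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, hidden_dim, output_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        return torch.tensordot(x, self.weight, dims=[[0, 1], [0, 1]]) + self.bias

batch_size, input_dim, hidden_dim, output_dim = 8, 3, 2, 1
model = LinearMultidimModel(input_dim, hidden_dim, output_dim)
x = torch.randn(batch_size, input_dim, hidden_dim)
target = torch.randn(batch_size, output_dim)  # illustrative random targets

# Detached copies of the parameters, passed explicitly so grad() can differentiate w.r.t. them
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, xi, ti):
    # Runs the unbatched model on ONE sample using the given parameters
    pred = functional_call(model, params, (xi,))
    return ((pred - ti) ** 2).mean()

# grad() differentiates w.r.t. the first argument (params);
# vmap maps over dim 0 of xi and ti while sharing params across the batch (None)
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, target)
print(per_sample_grads["weight"].shape)  # torch.Size([8, 3, 2, 1])

# Training with the vmapped model: ordinary autograd on the batched loss
model_batched = vmap(model)
loss = ((model_batched(x) - target) ** 2).mean()  # mean of the per-sample losses
loss.backward()

# The accumulated .grad equals the average of the per-sample gradients
assert torch.allclose(model.weight.grad, per_sample_grads["weight"].mean(0), atol=1e-5)
```

Passing in_dims=None for params keeps a single set of weights for the whole batch, so only the inputs and targets are split per sample.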