
Is there a `Split` equivalent to torch.nn.Sequential?


Here is sample code for a Sequential block:

self._encoder = nn.Sequential(
        # 1, 28, 28
        nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=3, padding=1),
        # 32, 10, 10 = 32, (28 + 2*1 - 3)//3 + 1, (28 + 2*1 - 3)//3 + 1
        nn.ReLU(True),
        nn.MaxPool2d(kernel_size=2, stride=2),
        # 32, 5, 5
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
        # 64, 3, 3
        nn.ReLU(True),
        nn.MaxPool2d(kernel_size=2, stride=1),
        # 64, 2, 2
)
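The shape comments follow from the usual output-size formula for Conv2d and MaxPool2d (with dilation 1): floor((size + 2*padding - kernel_size) / stride) + 1. A quick way to confirm them is to push a dummy batch through a standalone copy of the encoder and record the shape after each layer. This is a sketch that assumes a single-channel 28x28 input, as the comments imply:

```python
import torch
import torch.nn as nn

# Standalone copy of the encoder from the question (outside any class).
encoder = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=3, padding=1),
    nn.ReLU(True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
    nn.ReLU(True),
    nn.MaxPool2d(kernel_size=2, stride=1),
)

x = torch.zeros(1, 1, 28, 28)  # dummy batch: one single-channel 28x28 image
shapes = []
for layer in encoder:
    x = layer(x)
    shapes.append(tuple(x.shape[1:]))  # drop the batch dimension
    print(type(layer).__name__, shapes[-1])
```

The final entry is (64, 2, 2), matching the last comment in the block above.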

Is there some construct like nn.Sequential that puts modules in it in parallel?

I would now like to define something like

self._mean_logvar_layers = nn.Parallel(
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0),
)

Its output should be two streams of data, one for each module in self._mean_logvar_layers, which can then be fed to the rest of the network. Kind of like a multi-headed network.


My current implementation:

self._mean_layer = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0)
self._logvar_layer = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0)

and

def _encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    for layer in self._encoder:
        x = layer(x)

    mean_output = self._mean_layer(x)
    logvar_output = self._logvar_layer(x)

    return mean_output, logvar_output

I would like to treat the parallel construct as a layer.

Is that doable in PyTorch?


Solution

  • Sequential split

    What you can do is create a Parallel module like this (though I would give it a different name, since "Parallel" implies the branches actually run in parallel; Split would probably be a good name):

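    class Parallel(torch.nn.Module):
        def __init__(self, *modules: torch.nn.Module):
            super().__init__()
            # nn.ModuleList registers the branches so .parameters() and .to(device) see them;
            # a plain tuple would not, and the name `modules` would shadow nn.Module.modules().
            self.branches = torch.nn.ModuleList(modules)
    
        def forward(self, inputs):
            # Feed the same input to every branch and return one output per branch.
            return [branch(inputs) for branch in self.branches]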
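One detail matters when writing such a container: nn.Module only registers child modules that are assigned directly as attributes or kept in containers such as nn.ModuleList. Modules stored in a plain Python tuple are invisible to .parameters(), so an optimizer would never update them and .to(device) would not move them. The sketch below compares the two storage choices; the class names TupleParallel and ListParallel are illustrative only:

```python
import torch
import torch.nn as nn


class TupleParallel(nn.Module):
    """Stores branches in a plain tuple: they are NOT registered as submodules."""

    def __init__(self, *branches: nn.Module):
        super().__init__()
        self.branches = branches

    def forward(self, inputs):
        return [branch(inputs) for branch in self.branches]


class ListParallel(nn.Module):
    """Stores branches in nn.ModuleList: their parameters are registered."""

    def __init__(self, *branches: nn.Module):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, inputs):
        return [branch(inputs) for branch in self.branches]


def make_heads():
    # Two 64 -> 64 convolutions with a 2x2 kernel, as in the question.
    return [nn.Conv2d(64, 64, kernel_size=2) for _ in range(2)]


tuple_count = sum(p.numel() for p in TupleParallel(*make_heads()).parameters())
list_count = sum(p.numel() for p in ListParallel(*make_heads()).parameters())
print(tuple_count, list_count)  # prints: 0 32896
```

Each head has 64*64*2*2 weights plus 64 biases (16448 parameters), so only the ModuleList version exposes the expected 2 * 16448.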
    

    Now you can define it as you wanted:

    self._mean_logvar_layers = Parallel(
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0),
    )
    

    And use it like this:

    mean, logvar = self._mean_logvar_layers(x)
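Putting it together, here is a minimal sketch of the question's encoder with a Parallel head, storing the branches in nn.ModuleList so they are registered. The Encoder class name and the 28x28 single-channel input are assumptions for illustration:

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class Parallel(nn.Module):
    """Feeds the same input to every branch and returns the list of outputs."""

    def __init__(self, *branches: nn.Module):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        return [branch(inputs) for branch in self.branches]


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self._encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=1),
        )
        # Two independent heads, each with its own weights.
        self._mean_logvar_layers = Parallel(
            nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=0),
            nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=0),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self._encoder(x)  # (N, 64, 2, 2)
        mean, logvar = self._mean_logvar_layers(x)  # each (N, 64, 1, 1)
        return mean, logvar


mean, logvar = Encoder()(torch.randn(8, 1, 28, 28))
print(mean.shape, logvar.shape)
```

Both outputs have shape (8, 64, 1, 1), since a 2x2 kernel with no padding reduces the 2x2 feature map to 1x1.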
    

    One layer and split

    As suggested by @xdurch0, we could instead use a single layer and split its output across channels, using this module:

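    class Split(torch.nn.Module):
        def __init__(self, module: torch.nn.Module, parts: int, dim: int = 1):
            super().__init__()
            self.parts = parts  # number of pieces to split the output into
            self.dim = dim  # dimension to split along (1 = channels for NCHW tensors)
            self.module = module
    
        def forward(self, inputs):
            output = self.module(inputs)
            # Size of each piece; assumes output.shape[self.dim] is divisible by self.parts.
            chunk_size = output.shape[self.dim] // self.parts
            return torch.split(output, chunk_size, dim=self.dim)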
    

    Then, inside your neural network (note the 128 output channels, which will be split into 2 parts of 64 channels each):

    self._mean_logvar_layers = Split(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=1, padding=0),
        parts=2,
    )
    

    And use it as before:

    mean, logvar = self._mean_logvar_layers(x)
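The Split module computes a chunk size and calls torch.split, which quietly returns an extra piece when the size is not divisible by parts (for example, 129 channels with parts=2 gives pieces of 64, 64 and 1). A swapped-in alternative sketch, using a hypothetical SplitChunk module, cuts the output into parts pieces with torch.chunk and checks divisibility up front:

```python
import torch
import torch.nn as nn


class SplitChunk(nn.Module):
    """Hypothetical variant of Split: runs one module, then cuts its output into `parts` equal pieces."""

    def __init__(self, module: nn.Module, parts: int, dim: int = 1):
        super().__init__()
        self.module = module
        self.parts = parts
        self.dim = dim

    def forward(self, inputs: torch.Tensor):
        output = self.module(inputs)
        size = output.shape[self.dim]
        # Uneven sizes would silently change the number of pieces, so fail loudly instead.
        if size % self.parts != 0:
            raise ValueError(f"dim {self.dim} has size {size}, not divisible by {self.parts}")
        return torch.chunk(output, self.parts, dim=self.dim)


layer = SplitChunk(nn.Conv2d(64, 128, kernel_size=2), parts=2)
x = torch.randn(4, 64, 2, 2)
mean, logvar = layer(x)

# Compare with splitting by chunk size, as the Split module does.
full = layer.module(x)
a, b = torch.split(full, full.shape[1] // 2, dim=1)
print(mean.shape, torch.allclose(mean, a), torch.allclose(logvar, b))
```

For evenly divisible sizes, torch.chunk and torch.split produce the same pieces; the difference only shows up in the uneven case.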
    

    Why this approach?

    Both outputs are computed in one call to a single, wider convolution instead of two separate calls, which is usually faster; the trade-off is that the wider layer might not fit if you are short on GPU memory.
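A single Conv2d with 128 output channels is the same computation as two 64-channel convolutions stacked along the output-channel axis, which is why the wide layer can replace the two heads. The sketch below copies the weights of two separate heads into one wide layer and checks that the split outputs match (the seed and input shape are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mean_head = nn.Conv2d(64, 64, kernel_size=2)
logvar_head = nn.Conv2d(64, 64, kernel_size=2)

# One wide layer whose first 64 filters are mean_head's and last 64 are logvar_head's.
wide = nn.Conv2d(64, 128, kernel_size=2)
with torch.no_grad():
    wide.weight.copy_(torch.cat([mean_head.weight, logvar_head.weight], dim=0))
    wide.bias.copy_(torch.cat([mean_head.bias, logvar_head.bias], dim=0))

x = torch.randn(8, 64, 2, 2)
mean_wide, logvar_wide = torch.split(wide(x), 64, dim=1)
same_mean = torch.allclose(mean_wide, mean_head(x), atol=1e-5)
same_logvar = torch.allclose(logvar_wide, logvar_head(x), atol=1e-5)
print(same_mean, same_logvar)
```

The parameter count is identical in both setups (128 filters either way); only the number of separate convolution calls changes.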

    Can it work with Sequential?

    Yes, it is still a layer. However, the next layer has to accept a tuple (torch.Tensor, torch.Tensor) as its single input (or, for Parallel, a list of tensors).
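For example, a Sequential can contain a Split whose tuple output goes into a module written to unpack it. Reparameterize below is a hypothetical stand-in for "the rest of the network"; it applies the usual VAE reparameterization trick, z = mean + exp(0.5 * logvar) * eps:

```python
import torch
import torch.nn as nn


class Split(nn.Module):
    """Runs one module and splits its output into `parts` equal chunks along `dim`."""

    def __init__(self, module: nn.Module, parts: int, dim: int = 1):
        super().__init__()
        self.module = module
        self.parts = parts
        self.dim = dim

    def forward(self, inputs):
        output = self.module(inputs)
        chunk_size = output.shape[self.dim] // self.parts
        return torch.split(output, chunk_size, dim=self.dim)


class Reparameterize(nn.Module):
    """Hypothetical next layer: receives the (mean, logvar) tuple as one argument and samples z."""

    def forward(self, inputs):
        mean, logvar = inputs  # unpack the tuple produced by Split
        eps = torch.randn_like(mean)
        return mean + torch.exp(0.5 * logvar) * eps


model = nn.Sequential(
    Split(nn.Conv2d(64, 128, kernel_size=2), parts=2),  # -> ((N, 64, 1, 1), (N, 64, 1, 1))
    Reparameterize(),  # -> (N, 64, 1, 1)
    nn.Flatten(),  # -> (N, 64)
)
z = model(torch.randn(8, 64, 2, 2))
print(z.shape)
```

A standard layer such as nn.Linear placed directly after Split would fail, because it expects a tensor rather than a tuple.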

    nn.Sequential is itself just a layer, and a very simple one. Its forward is essentially:

    def forward(self, inp):
        for module in self:
            inp = module(inp)
        return inp
    

    It just passes the output of each module as the input to the next, and that's it.
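A quick check that a plain loop over the modules gives the same result as calling the Sequential itself; this is also why the loop in the question's _encode method is equivalent to calling self._encoder(x) directly. The layer sizes here are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)

# Manual version of Sequential.forward: each module's output becomes the next module's input.
out = x
for module in seq:
    out = module(out)

matches = torch.allclose(out, seq(x))
print(matches)
```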