
Parallel analog to torch.nn.Sequential container


Just wondering: why can't I find this in torch.nn? nn.Sequential is pretty convenient; it lets you define a network in one place, clearly and visually, but it is restricted to very simple architectures. With a parallel analog (and a little help from "identity" nodes for residual connections), it would form a complete method for constructing any feedforward net combinatorially. Am I missing something?
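
The idea can be illustrated with a small sketch. Note that `SumParallel` is a hypothetical container invented here for illustration (it is not part of torch.nn); combined with `nn.Identity`, it expresses a residual connection:

```python
import torch
import torch.nn as nn

class SumParallel(nn.Module):
    """Hypothetical container: run all branches on the same input, sum the results."""
    def __init__(self, *branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, x):
        return sum(branch(x) for branch in self.branches)

# A residual block is then just SumParallel(identity, transform):
block = SumParallel(
    nn.Identity(),                    # skip connection
    nn.Conv2d(16, 16, 3, padding=1),  # learned transform, shape-preserving
)
x = torch.randn(1, 16, 8, 8)
assert block(x).shape == (1, 16, 8, 8)
```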


Solution

  • Well, maybe it isn't in the standard module collection simply because it can be defined very easily:

    import torch
    import torch.nn as nn

    class ParallelModule(nn.Sequential):
        """Runs its child modules on the same input in parallel and
        concatenates their outputs along the channel dimension."""
        def __init__(self, *args):
            super().__init__(*args)

        def forward(self, input):
            # Feed the same input to every branch and collect the results.
            output = [module(input) for module in self]
            # Concatenate along dim=1 (the channel dimension in NCHW layout).
            return torch.cat(output, dim=1)
    

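    A quick sanity check of the concatenation behavior (the class is restated compactly here so the snippet is self-contained; the channel counts are arbitrary example values): two branches producing 32 channels each are concatenated along `dim=1` into 64 channels, with spatial size unchanged.

```python
import torch
import torch.nn as nn

class ParallelModule(nn.Sequential):
    def forward(self, input):
        # Same input to every branch; concatenate along channels.
        return torch.cat([module(input) for module in self], dim=1)

pm = ParallelModule(
    nn.Conv2d(64, 32, 1),  # branch 1: 32 output channels
    nn.Conv2d(64, 32, 1),  # branch 2: 32 output channels
)
x = torch.randn(1, 64, 8, 8)
assert pm(x).shape == (1, 64, 8, 8)  # 32 + 32 channels, 8x8 preserved
```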
    Inheriting "Parallel" from "Sequential" is ideologically questionable, but it works well. One can now define networks like the one pictured (image rendered by torchviz) with the following code:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.net = nn.Sequential(
                nn.Conv2d(  1, 32, 3, padding=1 ),        nn.ReLU(),
                nn.Conv2d( 32, 64, 3, padding=1 ),        nn.ReLU(),
                nn.MaxPool2d( 3, stride=2 ),              nn.Dropout2d( 0.25 ),
    
                ParallelModule(
                    nn.Conv2d(  64, 64, 1 ),
                    nn.Sequential( 
                        nn.Conv2d(  64, 64, 1 ),          nn.ReLU(),
                        ParallelModule(
                            nn.Conv2d(  64, 32, (3,1), padding=(1,0) ),
                            nn.Conv2d(  64, 32, (1,3), padding=(0,1) ),
                        ),
                    ),
                    nn.Sequential( 
                        nn.Conv2d(  64, 64, 1 ),          nn.ReLU(),
                        nn.Conv2d(  64, 64, 3, padding=1 ), nn.ReLU(),
                        ParallelModule(
                            nn.Conv2d(  64, 32, (3,1), padding=(1,0) ),
                            nn.Conv2d(  64, 32, (1,3), padding=(0,1) ),
                        ),
                    ),
                    nn.Sequential( 
                        #PrinterModule(),
                        nn.AvgPool2d( 3, stride=1, padding=1 ),
                        nn.Conv2d(  64, 64, 1 ),
                    ),
                ),
                nn.ReLU(),
                nn.Conv2d( 256, 64, 1 ),                  nn.ReLU(),
    
                nn.Conv2d( 64, 128, 3, padding=1 ),       nn.ReLU(),
                nn.MaxPool2d( 3, stride=2 ),              nn.Dropout2d( 0.5 ),
                nn.Flatten(),
                nn.Linear( 4608, 128 ),                   nn.ReLU(),
                nn.Linear(  128,  10 ),                   nn.LogSoftmax( dim=1 ),
            )
    
        def forward(self, x):
            return self.net(x)  # call the module itself, not .forward(), so hooks run
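
    For reference, here is a sketch of the spatial-size bookkeeping that leads to `nn.Linear(4608, 128)`, assuming a 28×28 single-channel input (e.g. MNIST); all the padded convolutions and the padded AvgPool2d preserve spatial size, so only the two pooling layers shrink it:

```python
def pool_out(size, kernel=3, stride=2):
    """Output size of MaxPool2d(kernel, stride) with no padding (floor mode)."""
    return (size - kernel) // stride + 1

s = 28           # padded convs keep the spatial size at 28
s = pool_out(s)  # first MaxPool2d(3, stride=2): 28 -> 13
                 # the parallel block and subsequent padded convs keep 13
s = pool_out(s)  # second MaxPool2d(3, stride=2): 13 -> 6
channels = 128   # channels entering nn.Flatten()
assert channels * s * s == 4608  # matches nn.Linear(4608, 128)
```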