pytorch, pytorch-lightning

PyTorch: dynamic number of layers?


I am trying to specify a dynamic number of layers, but I seem to be doing it wrong. When I create the 100 layers in a loop as shown here, I get an error in the forward step. When I define a layer directly as an attribute, it works. A simplified example is below:

class PredictFromEmbeddParaSmall(LightningModule):
    def __init__(self, hyperparams={'lr': 0.0001}):
        super(PredictFromEmbeddParaSmall, self).__init__()
        #Input is something like tensor.size=[768*100]
        self.TO_ILLUSTRATE = nn.Linear(768, 5)
        self.para_count = 100
        self.enc_red = []
        for i in range(self.para_count):
            self.enc_red.append(nn.Linear(768, 5))
        # gather the layers output sth
        self.dense_simple1 = nn.Linear(5*100, 2)
        self.output = nn.Sigmoid()
    def forward(self, x):
        # first input to enc_red
        x_vecs = []
        for i in range(self.para_count):
            layer = self.enc_red[i]
            # The first dim is the batch size here, output is correct
            processed_slice = x[:, i * 768:(i + 1) * 768]
            # This works and give the out of size 5
            rand = self.TO_ILLUSTRATE(processed_slice)
            #This will fail? Error below
            ret = layer(processed_slice)
            #more things happening we can ignore right now since we fail earlier

I get this error when executing "ret = layer(processed_slice)":

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_addmm

Is there a smarter way to program this, or a way to solve the error?


Solution

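  • You should use an nn.ModuleList instead of a plain Python list: https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html . PyTorch only tracks submodules that are registered on the parent module, which happens when you assign an nn.Module (or a container such as nn.ModuleList) as an attribute. Layers stored in a plain list are not registered, so they are missing from model.parameters() and are not moved when the model is moved to the GPU. They stay on the CPU while the input is on the GPU, which produces the error you saw.

    Your code should look something like this: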

    class PredictFromEmbeddParaSmall(LightningModule):
        def __init__(self, hyperparams={'lr': 0.0001}):
            super(PredictFromEmbeddParaSmall, self).__init__()
            #Input is something like tensor.size=[768*100]
            self.TO_ILLUSTRATE = nn.Linear(768, 5)
            self.para_count = 100
            self.enc_red = nn.ModuleList()                   # << MODIFIED LINE <<
            for i in range(self.para_count):
                self.enc_red.append(nn.Linear(768, 5))
            # gather the layers output sth
            self.dense_simple1 = nn.Linear(5*100, 2)
            self.output = nn.Sigmoid()
        def forward(self, x):
            # first input to enc_red
            x_vecs = []
            for i in range(self.para_count):
                layer = self.enc_red[i]
                # The first dim is the batch size here, output is correct
                processed_slice = x[:, i * 768:(i + 1) * 768]
                # This works and give the out of size 5
                rand = self.TO_ILLUSTRATE(processed_slice)
                # This now works: the layer is registered, so it is on the same device as x
                ret = layer(processed_slice)
                #more things happening we can ignore right now since we fail earlier
    

    Then it should work all right!
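
The difference between a plain list and nn.ModuleList can be checked without a GPU. The following is a minimal sketch, not the asker's model: the class names WithList and WithModuleList, the helper count_params, and the choice of 3 layers are made up for illustration. It counts registered parameters and uses .double() as a stand-in for .cuda(), since both only touch registered parameters.

```python
import torch
from torch import nn


class WithList(nn.Module):
    """Hypothetical module that keeps its layers in a plain Python list."""

    def __init__(self, n_layers=3):
        super().__init__()
        # Plain list: nn.Module does not register these layers.
        self.layers = [nn.Linear(768, 5) for _ in range(n_layers)]


class WithModuleList(nn.Module):
    """Hypothetical module that keeps its layers in an nn.ModuleList."""

    def __init__(self, n_layers=3):
        super().__init__()
        # nn.ModuleList registers every layer as a submodule.
        self.layers = nn.ModuleList([nn.Linear(768, 5) for _ in range(n_layers)])

    def forward(self, x):
        # Apply layer i to the i-th 768-wide slice, then concatenate the outputs.
        outs = [layer(x[:, i * 768:(i + 1) * 768])
                for i, layer in enumerate(self.layers)]
        return torch.cat(outs, dim=1)


def count_params(module):
    # Number of scalar parameters that the module has registered.
    return sum(p.numel() for p in module.parameters())


print(count_params(WithList()))        # 0
print(count_params(WithModuleList()))  # 3 * (768 * 5 + 5) = 11535

# Conversions such as .cuda(), .to(device) or .double() skip unregistered layers.
print(WithList().double().layers[0].weight.dtype)        # torch.float32
print(WithModuleList().double().layers[0].weight.dtype)  # torch.float64
```

If the first count is 0 for your own model, the optimizer will not see those layers either, so they would not be trained even on the CPU.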

    Edit: alternative way.

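    Instead of ModuleList you can also use nn.Sequential, which lets you avoid the for loop in the forward pass. Keep in mind that nn.Sequential chains its layers, passing each layer's output to the next one, so it only fits when the layers are meant to run one after another and their shapes match. It does not apply a separate layer to each slice of the input the way the loop above does. It also means you do not have access to intermediate activations, so it is not the solution for you if you need them.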

    class PredictFromEmbeddParaSmall(LightningModule):
        def __init__(self, hyperparams={'lr': 0.0001}):
            super(PredictFromEmbeddParaSmall, self).__init__()
            #Input is something like tensor.size=[768*100]
            self.TO_ILLUSTRATE = nn.Linear(768, 5)
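            layers = []
            for i in range(100):
                layers.append(nn.Linear(768, 5))
            # NOTE: nn.Sequential feeds each layer's output into the next one,
            # so in real code consecutive layer sizes must match.
            self.enc_red = nn.Sequential(*layers)            # << MODIFIED LINE <<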
            # gather the layers output sth
            self.dense_simple1 = nn.Linear(5*100, 2)
            self.output = nn.Sigmoid()
        def forward(self, x):
            # first input to enc_red
            x_vecs = []
            out = self.enc_red(x)                            # << MODIFIED LINE <<
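
For reference, nn.Sequential applies its layers in order and passes each output to the next layer, so consecutive shapes have to match. Here is a minimal sketch of that behaviour. The layer sizes (768 -> 64 -> 5) and the ReLU in between are assumptions chosen only for illustration and are not taken from the question.

```python
import torch
from torch import nn

# Illustrative sizes only: each layer's input size matches the previous output size.
chain = nn.Sequential(
    nn.Linear(768, 64),
    nn.ReLU(),
    nn.Linear(64, 5),
)

x = torch.randn(4, 768)

# Calling the container ...
out = chain(x)

# ... is the same as calling each layer in order on the previous output.
manual = x
for layer in chain:
    manual = layer(manual)

print(out.shape)                 # torch.Size([4, 5])
print(torch.equal(out, manual))  # True
```

Like nn.ModuleList, nn.Sequential registers its layers, so they follow the model to the GPU. The difference is only in how the layers are called.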