Tags: python · scikit-learn · neural-network · pytorch · skorch

How to pass parameters to forward function of my torch nn.module from skorch.NeuralNetClassifier.fit()


I have extended nn.Module to implement my network, whose forward function looks like this:

def forward(self, X, **kwargs):
    batch_size, seq_len = X.size()

    length = kwargs.get('length')  # sequence lengths, expected as a keyword argument
    embedded = self.embedding(X)   # [batch_size, seq_len, embedding_dim]
    if self.use_padding:
        if length is None:
            raise AttributeError("Length must be a tensor when using padding")
        embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True)

    hidden, cell = self.init_hidden(batch_size)
    if self.rnn_unit == 'rnn':
        out, _ = self.rnn(embedded, hidden)
    elif self.rnn_unit == 'lstm':
        out, (hidden, cell) = self.rnn(embedded, (hidden, cell))

    # unpack if padding was used
    if self.use_padding:
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

    return out

I initialized a skorch NeuralNetClassifier like this,

from torch import nn
from torch.optim import Adam
from skorch import NeuralNetClassifier

net = NeuralNetClassifier(
    model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam,
    max_epochs=8,
    lr=0.01,
    batch_size=32
)

Now if I call net.fit(X, y, length=X_len), it throws an error:

TypeError: __call__() got an unexpected keyword argument 'length'

According to the documentation, the fit function accepts additional keyword arguments as fit_params:

**fit_params : dict
   Additional parameters passed to the ``forward`` method of
   the module and to the ``self.train_split`` call.

and the source code always sends these parameters on to train_split, where my keyword argument is obviously not recognized.
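
For illustration, the TypeError makes sense given that signature: skorch hands **fit_params to the train_split callable (by default a CVSplit instance, whose __call__ does not know a length keyword). A hypothetical custom split would have to accept it explicitly; my_train_split below is an invented name, just a sketch:

def my_train_split(dataset, y, length=None):
    # keywords passed to fit() arrive here as well as in forward()
    return dataset, None  # no validation split, for brevity

net = NeuralNetClassifier(
    model,
    train_split=my_train_split,  # other arguments as before
)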

Is there a way around this so that I can pass the arguments to my forward function?


Solution

  • The fit_params parameter is intended for passing information that is relevant to data splits and the model alike, such as split groups.

    In your case, you are passing additional data to the module via fit_params, which is not what it is intended for. In fact, you could easily run into trouble doing this: if you, for example, enable batch shuffling on the train data loader, your lengths and your data become misaligned.

    The best way to do this is already described in the answer to your question on the issue tracker:

    X_dict = {'X': X, 'length': X_len}
    net.fit(X_dict, y)
    

    Since skorch supports dicts, you can simply add the lengths to your input dict and have them batched and passed through the same data loader together with your data. In your module you can then access them via a named parameter in forward:

    def forward(self, X, length):
         return ...
    

    Further documentation of this behaviour can be found in the docs; a fuller end-to-end sketch follows below.
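
    Putting it together, here is a minimal, self-contained sketch of the dict-input approach. The module, data, and shapes are illustrative assumptions for demonstration, not your actual model:

    import torch
    from torch import nn
    from torch.optim import Adam
    from skorch import NeuralNetClassifier

    class RNNClassifier(nn.Module):
        # Illustrative stand-in for the model in the question.
        def __init__(self, vocab_size=100, embedding_dim=16,
                     hidden_dim=32, n_classes=2):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
            self.fc = nn.Linear(hidden_dim, n_classes)

        def forward(self, X, length):
            # `length` is delivered per batch because it is a key of the input dict
            embedded = self.embedding(X)
            packed = nn.utils.rnn.pack_padded_sequence(
                embedded, length, batch_first=True, enforce_sorted=False)
            _, (hidden, _) = self.rnn(packed)
            return self.fc(hidden[-1])

    X = torch.randint(0, 100, (64, 10))    # padded integer sequences
    X_len = torch.randint(1, 11, (64,))    # true sequence lengths
    y = torch.randint(0, 2, (64,)).numpy()

    net = NeuralNetClassifier(
        RNNClassifier,
        criterion=nn.CrossEntropyLoss,
        optimizer=Adam,
        max_epochs=2,
        lr=0.01,
        batch_size=32,
    )

    # Both dict entries are batched by the same data loader and
    # passed to forward() as keyword arguments.
    net.fit({'X': X, 'length': X_len}, y)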