Search code examples
python-3.xskorch

Get the validation/train loss from Skorch fit


Is there a way to get the train/validation loss from a Skorch fit in e.g a list (if you want to do some plotting, statistics)?


Solution

  • You can use the history (which allows slicing over time) to get this information. For example:

    train_loss = net.history[:, 'train_loss']
    

    which returns the train_loss for each recorded epoch.

    Here is an example based on the following.

    import numpy as np
    import torch
    import torch.nn.functional as F
    from matplotlib import pyplot as plt
    from sklearn.datasets import make_classification
    from skorch import NeuralNetClassifier
    from torch import nn
    
    torch.manual_seed(0)
    
    class ClassifierModule(nn.Module):
        def __init__(
                self,
                num_units=10,
                nonlin=F.relu,
                dropout=0.5,
        ):
            super(ClassifierModule, self).__init__()
            self.num_units = num_units
            self.nonlin = nonlin
            self.dropout = dropout
    
            self.dense0 = nn.Linear(20, num_units)
            self.nonlin = nonlin
            self.dropout = nn.Dropout(dropout)
            self.dense1 = nn.Linear(num_units, 10)
            self.output = nn.Linear(10, 2)
    
        def forward(self, X, **kwargs):
            X = self.nonlin(self.dense0(X))
            X = self.dropout(X)
            X = F.relu(self.dense1(X))
            X = F.softmax(self.output(X), dim=-1)
            return X
    
    net = NeuralNetClassifier(
        ClassifierModule,
        max_epochs=20,
        lr=0.1,
        # device='cuda',  # uncomment this to train with CUDA
    )
    
    X, y = make_classification(1000, 20, n_informative=10, random_state=0)
    X, y = X.astype(np.float32), y.astype(np.int64)
    
    net.fit(X, y)
    
    train_loss = net.history[:, 'train_loss']
    valid_loss = net.history[:, 'valid_loss']
    
    plt.plot(train_loss, 'o-', label='training')
    plt.plot(valid_loss, 'o-', label='validation')
    plt.legend()
    plt.show()
    

    And results:

    resulting plot