Search code examples
python deep-learning neural-network pytorch early-stopping

early stopping in PyTorch


I tried to implement an early stopping function to keep my neural network model from overfitting. I'm pretty sure that the logic is fine, but for some reason it doesn't work. I want the early stopping function to return True when the validation loss has been greater than the training loss for some number of epochs. But it returns False all the time, even though the validation loss becomes a lot greater than the training loss. Could you see where the problem is, please?

early stopping function

def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    """Report whether the validation/train loss gap calls for stopping.

    Args:
        train_loss: Training loss of the current epoch.
        validation_loss: Validation loss of the current epoch.
        min_delta: Gap (validation minus train) that has to be exceeded
            for the epoch to be counted.
        tolerance: Number of counted epochs required before stopping.

    Returns:
        True when the counter reaches ``tolerance``; otherwise None.
    """
    # The counter is a local that starts from zero on every call, so no
    # count is carried over from one epoch to the next.
    counter = 0
    loss_gap = validation_loss - train_loss
    if not loss_gap > min_delta:
        return None
    counter += 1
    return True if counter >= tolerance else None

calling the function during the training

# Run one training pass and one validation pass per epoch, then ask
# early_stopping() whether the loop should end.
for i in range(epochs):

    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(
        model, train_dataloader, loss_func, optimiser, device
    )
    train_loss.append(epoch_train_loss)

    # validation, run with gradient tracking switched off
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(
            model, validate_dataloader, loss_func, device
        )
    validation_loss.append(epoch_validate_loss)

    # early stopping
    should_stop = early_stopping(
        epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance=20
    )
    if should_stop:
        print("We are at epoch:", i)
        break

EDIT: The train and validation loss (two plots; the images are not reproduced here).

EDIT2:

def train_validate(model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    """Train and validate ``model`` for up to ``epochs`` epochs with early stopping.

    Each epoch runs ``train_one_epoch`` followed by ``validate_one_epoch``
    (the latter under ``torch.no_grad()``), records both losses, and feeds
    them to a single ``EarlyStopping`` instance. The loop ends early as soon
    as that instance sets its ``early_stop`` flag.

    Args:
        model: Model passed through to the per-epoch helpers.
        train_dataloader: Data source handed to ``train_one_epoch``.
        validate_dataloader: Data source handed to ``validate_one_epoch``.
        loss_func: Loss function passed through to both helpers.
        optimiser: Optimiser passed through to ``train_one_epoch``.
        device: Device argument passed through to both helpers.
        epochs: Maximum number of epochs to run.

    Returns:
        A ``(train_loss, validation_loss)`` tuple of lists holding the
        per-epoch losses recorded before the loop ended.
    """
    train_loss = []
    validation_loss = []
    min_delta = 5

    # Build the stopper once, before the loop: it has to keep its counter
    # from epoch to epoch. Re-creating it inside the loop would restart the
    # counter at zero every epoch, so a tolerance above 1 could never be met.
    early_stopping = EarlyStopping(tolerance=2, min_delta=min_delta)

    for e in range(epochs):

        print(f"Epoch {e+1}")
        # The second return value of train_one_epoch is not used here.
        epoch_train_loss, _ = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)

        # validation
        with torch.no_grad():
            epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)

        # early stopping
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break

    return train_loss, validation_loss

Solution

  • The problem with your implementation is that whenever you call early_stopping() the counter is re-initialized to 0, so it can never get past 1 and a tolerance of 20 is never reached.

    Here is a working solution using an object-oriented approach with __init__() and __call__() instead:

    class EarlyStopping:
        """Flag a stop once validation loss has exceeded training loss often enough.

        Every call compares the losses of one epoch. When the validation loss
        is more than ``min_delta`` above the training loss, an internal
        counter goes up by one; once the counter reaches ``tolerance`` the
        ``early_stop`` flag is set. The counter is never reset, so the
        counted epochs do not need to be consecutive.
        """

        def __init__(self, tolerance=5, min_delta=0):
            """Store the thresholds and start with a zero counter."""
            # Threshold that the train/validation gap has to exceed.
            self.min_delta = min_delta
            # Number of counted epochs after which early_stop is raised.
            self.tolerance = tolerance
            # State that persists between calls.
            self.early_stop = False
            self.counter = 0

        def __call__(self, train_loss, validation_loss):
            """Update the counter and the ``early_stop`` flag for one epoch."""
            loss_gap = validation_loss - train_loss
            if not loss_gap > self.min_delta:
                return

            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
    

    Call it like this:

    # Create the stopper once, outside the loop, so that its counter
    # persists from one epoch to the next.
    early_stopping = EarlyStopping(tolerance=5, min_delta=10)

    for i in range(epochs):

        print(f"Epoch {i+1}")
        epoch_train_loss, pred = train_one_epoch(
            model, train_dataloader, loss_func, optimiser, device
        )
        train_loss.append(epoch_train_loss)

        # validation, run with gradient tracking switched off
        with torch.no_grad():
            epoch_validate_loss = validate_one_epoch(
                model, validate_dataloader, loss_func, device
            )
        validation_loss.append(epoch_validate_loss)

        # early stopping: update the stopper, then check its flag
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if not early_stopping.early_stop:
            continue
        print("We are at epoch:", i)
        break
    

    Example:

    # Worked example: stop after two epochs whose validation loss is more
    # than 5 above the training loss.
    early_stopping = EarlyStopping(tolerance=2, min_delta=5)

    train_loss = [
        642.14990234,
        601.29278564,
        561.98400879,
        530.01501465,
        497.1098938,
        466.92709351,
        438.2364502,
        413.76028442,
        391.5090332,
        370.79074097,
    ]
    validate_loss = [
        509.13619995,
        497.3125,
        506.17315674,
        497.68960571,
        505.69918823,
        459.78610229,
        480.25592041,
        418.08630371,
        446.42675781,
        372.09902954,
    ]

    # Feed the losses to the stopper one epoch at a time. The gap exceeds
    # min_delta at index 4 and again at index 6, which is where it stops.
    for i, (epoch_train, epoch_validate) in enumerate(zip(train_loss, validate_loss)):

        early_stopping(epoch_train, epoch_validate)
        print(f"loss: {train_loss[i]} : {validate_loss[i]}")
        if not early_stopping.early_stop:
            continue
        print("We are at epoch:", i)
        break
    
    

    Output:

    loss: 642.14990234 : 509.13619995
    loss: 601.29278564 : 497.3125
    loss: 561.98400879 : 506.17315674
    loss: 530.01501465 : 497.68960571
    loss: 497.1098938 : 505.69918823
    loss: 466.92709351 : 459.78610229
    loss: 438.2364502 : 480.25592041
    We are at epoch: 6