Search code examples
Tags: python, validation, pytorch, conv-neural-network, prediction

How to load a trained model and run inference to get the predicted data


I trained and saved a CNN model for 1000 epochs and now want to retrieve the validation results (the predicted images). In the code below, test_pred and test_real collect the predicted and real images of the validation set. Should I load the saved model and run it for one more epoch to retrieve the predicted images (this ends in CUDA out of memory, since the data are huge)? Or is there another way? Part of my code is below:

# Train for `epochs` epochs; after each epoch, evaluate on the test set and
# collect its predictions (test_pred) and targets (test_real).
# NOTE(review): model, optimizer, criterion, mse, l1loss, device, train_loader,
# test_loader, epochs, save_model and save_checkpoint are defined elsewhere.
for epoch in range(epochs):
    # Per-epoch accumulators: per-batch losses and per-batch sample counts.
    mse_train_losses = []
    mae_train_losses = []
    N_train = []
    mse_val_losses = []
    mae_val_losses = []
    N_test = []

    model.train()
    for data in train_loader:
        x_train_batch = data[0].to(device, dtype=torch.float)
        y_train_batch = data[1].to(device, dtype=torch.float)

        y_train_pred = model(x_train_batch)  # forward pass
        mse_train_loss = criterion(y_train_batch, y_train_pred, x_train_batch, mse)
        # The MAE is only recorded; the optimizer steps on the MSE loss alone.
        mae_train_loss = criterion(y_train_batch, y_train_pred, x_train_batch, l1loss)

        optimizer.zero_grad()
        mse_train_loss.backward()
        optimizer.step()

        mse_train_losses.append(mse_train_loss.item())
        mae_train_losses.append(mae_train_loss.item())
        N_train.append(len(x_train_batch))

    test_pred = []
    test_real = []
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            x_test_batch = data[0].to(device, dtype=torch.float)
            y_test_batch = data[1].to(device, dtype=torch.float)

            y_test_pred = model(x_test_batch)
            mse_val_loss = criterion(y_test_batch, y_test_pred, x_test_batch, mse)
            mae_val_loss = criterion(y_test_batch, y_test_pred, x_test_batch, l1loss)

            mse_val_losses.append(mse_val_loss.item())
            mae_val_losses.append(mae_val_loss.item())
            N_test.append(len(x_test_batch))

            # Store copies on the CPU: keeping every batch on the GPU makes
            # device memory grow with the size of the test set (CUDA OOM).
            test_pred.append(y_test_pred.cpu())
            test_real.append(y_test_batch.cpu())

    # Checkpoint AFTER this epoch's training so the saved weights include it.
    # Saving at the top of the loop stored pre-training weights at epoch 0 and
    # never saved the weights produced by the final epoch.
    if save_model and ((epoch + 1) % 50 == 0 or epoch + 1 == epochs):
        checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
        save_checkpoint(checkpoint)

Solution

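  • You do not need to run the model again: when the loop finishes, test_pred and test_real already hold the last epoch's validation outputs. The out-of-memory error comes from appending GPU tensors to these lists, which keeps every validation batch in GPU memory. Move each tensor to the CPU with .cpu() as you append it (to both lists), like this: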

    # Copy each batch to host memory so the GPU does not keep all batches alive.
    test_pred.append(y_test_pred.cpu())
    test_real.append(y_test_batch.cpu())