During the training phase, I save the model parameters whenever the performance metric improves:
if performance_metric.item() > max_performance:
    max_performance = performance_metric.item()
    torch.save(neural_net.state_dict(), PATH + '/best_model.pt')
This is the neural network model used:
import torch
import torch.nn as nn

class Neural_Net(nn.Module):
    def __init__(self, M, shape_input, batch_size):
        super(Neural_Net, self).__init__()
        self.lstm = nn.LSTM(shape_input, M)
        # self.dense1 = nn.Linear(shape_input, M)
        self.dense1 = nn.Linear(M, M)  # used with the LSTM
        torch.nn.init.xavier_uniform_(self.dense1.weight)
        self.dense2 = nn.Linear(M, M)
        torch.nn.init.xavier_uniform_(self.dense2.weight)
        self.dense3 = nn.Linear(M, 1)
        torch.nn.init.xavier_uniform_(self.dense3.weight)
        self.drop = nn.Dropout(0.7)
        self.batchnorm1 = nn.BatchNorm1d(M)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        # initial hidden and cell states for the LSTM
        self.hidden_cell = (torch.zeros(1, batch_size, M),
                            torch.zeros(1, batch_size, M))

    def forward(self, x):
        lstm_out, self.hidden_cell = self.lstm(x.view(1, len(x), -1), self.hidden_cell)
        x = self.drop(self.relu(self.dense1(self.batchnorm1(lstm_out.view(len(x), -1)))))
        x = self.drop(self.relu(self.dense2(x)))
        x = self.relu(self.dense3(x))
        return x
After that, I load the best parameters and set the model to evaluation mode:
neural_net.load_state_dict(torch.load(PATH+'/best_model.pt'))
neural_net.eval()
The results are completely random. When I set train() instead, the performance is similar to that of the selected best model parameters.
Is there an important aspect of eval() that I am forgetting? Is the batch normalization used correctly? I am using the same batch size in the test phase as in the training phase.
Without knowing your batch size, your training/test dataset sizes, or the discrepancies between the training and test datasets, it is hard to say for certain, but this issue has been discussed on the PyTorch forums previously here.
In my experience, it sounds very much like the latent representation of your training data in the model is significantly different from the representation of your validation data. The main advice I can give is to try reducing the momentum of your batchnorm layer. It might also be worth substituting a LayerNorm layer (which doesn't track a running mean/standard deviation), or setting track_running_stats=False on the BatchNorm1d layer, and seeing if the problem persists.
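For illustration, here is a minimal sketch of those three options, using the layer and M from the model above (the momentum value of 0.01 is just an assumption to experiment with, not a recommended setting):

# Option 1: lower the momentum so the running statistics change more slowly
# (PyTorch's default momentum is 0.1; 0.01 is an arbitrary value to try).
self.batchnorm1 = nn.BatchNorm1d(M, momentum=0.01)

# Option 2: don't track running statistics at all; in eval() the layer then
# normalizes with the statistics of the current batch.
self.batchnorm1 = nn.BatchNorm1d(M, track_running_stats=False)

# Option 3: replace batchnorm with layernorm, which normalizes each sample
# independently and keeps no running mean/std.
self.batchnorm1 = nn.LayerNorm(M)

Note that PyTorch updates the running estimate as x_new = (1 - momentum) * x_running + momentum * x_batch, so a smaller momentum makes the running statistics less sensitive to any individual batch.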