We are training a ResNet model on the CIFAR-10 dataset and are trying to do the following:
After training the model, we want to simulate dropout during evaluation, while feeding the test data to it. I know this might sound odd because dropout is a regularization mechanism, but we are doing it as part of an experiment.
One option we are considering is to use the state_dict: create a deep copy to preserve the original values, and then modify the values in it manually.
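A minimal sketch of that state_dict idea, for concreteness. The small `nn.Sequential` model, the drop probability `p`, and the choice to mask only tensors whose name ends in `weight` are all assumptions made for illustration. Note that zeroing weights this way is closer to DropConnect than to standard dropout, which masks activations rather than parameters.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the trained ResNet.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
net.eval()

# Deep copy so the saved tensors do not share storage with the live parameters.
original_state = copy.deepcopy(net.state_dict())

p = 0.5  # assumed drop probability
masked_state = {}
for name, tensor in original_state.items():
    if name.endswith("weight"):
        # Zero each weight with probability p and rescale the survivors by 1 / (1 - p).
        keep = (torch.rand_like(tensor) >= p).to(tensor.dtype)
        masked_state[name] = tensor * keep / (1.0 - p)
    else:
        masked_state[name] = tensor.clone()

# Evaluate with the masked weights.
net.load_state_dict(masked_state)
x = torch.randn(2, 8)
with torch.no_grad():
    noisy_out = net(x)

# Restore the untouched weights afterwards.
net.load_state_dict(original_state)
with torch.no_grad():
    clean_out = net(x)
```

A fresh mask has to be drawn and loaded for every stochastic pass, which makes this approach fairly heavy compared with toggling module modes.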
We also saw that net.eval() switches the dropout layers to eval mode instead of training mode. Maybe there is a way to use this mechanism to simulate dropout during evaluation?
Are there better ways to achieve what I am trying to do?
The dropout module is disabled during evaluation mode, i.e. after nn.Module.eval is called on your model. If you want it to be enabled after calling .eval, you can call nn.Module.train on each of the nn.Dropout modules within your model. The nn.Module.apply method makes this pretty easy. You can encapsulate all of this in an override of your model's train method.
    import torch.nn as nn

    class MyModel(nn.Module):
        ...  # your model's implementation

        # note that self.eval() just calls self.train(False), which is why
        # we are overriding train
        def train(self, mode=True):
            # recursively applies train mode to self and submodules
            super().train(mode)

            # When mode=False (i.e. .eval() is called), re-enable dropout layers
            # by recursively applying this function to each submodule
            def enable_dropout(mod: nn.Module):
                if isinstance(mod, nn.Dropout):
                    mod.train()

            if not mode:
                self.apply(enable_dropout)
            return self

    model = MyModel(...)  # constructor arguments elided
That way, after calling model.eval(), everything except the dropout layers will behave as if in evaluation mode. Calling model.train() will work as it did before.
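To check that the override behaves as described, a small self-contained sketch follows. `TinyNet`, its layer names, and its sizes are hypothetical and only stand in for the real model. The sketch inspects each submodule's `training` flag after `.eval()` and runs the same input twice to see whether dropout is still sampling.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model; stands in for the ResNet from the question."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.bn = nn.BatchNorm1d(16)
        self.drop = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.fc2(self.drop(torch.relu(self.bn(self.fc1(x)))))

    def train(self, mode=True):
        super().train(mode)

        def enable_dropout(mod: nn.Module):
            if isinstance(mod, nn.Dropout):
                mod.train()

        if not mode:
            self.apply(enable_dropout)
        return self


torch.manual_seed(0)
model = TinyNet()
model.eval()

# Collect the training flag of every named submodule (skip the root, whose name is "").
flags = {name: mod.training for name, mod in model.named_modules() if name}

x = torch.randn(4, 8)
with torch.no_grad():
    out_a = model(x)
    out_b = model(x)

# With dropout active, two passes over the same input draw different masks.
outputs_differ = not torch.equal(out_a, out_b)
```

Here `flags` should show only `drop` left in training mode, while `bn` uses its running statistics as in a normal evaluation pass.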
Of course, if you don't want to override the model's train method, you can do this outside of the model.
    # if you don't want to override the train method then you can just do the
    # same thing as the above snippet outside of the class method. E.g.
    model.eval()

    def enable_dropout(mod: nn.Module):
        if isinstance(mod, nn.Dropout):
            mod.train()

    model.apply(enable_dropout)
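One caveat with the `isinstance(mod, nn.Dropout)` check: `nn.Dropout2d`, `nn.Dropout3d`, and `nn.AlphaDropout` are separate classes rather than subclasses of `nn.Dropout`, so they would be skipped. The sketch below is one possible variation that checks a tuple of types and then collects several stochastic forward passes. The names `DROPOUT_TYPES` and `predict_with_dropout`, the toy network, and the number of passes are assumptions for illustration, not part of PyTorch.

```python
import torch
import torch.nn as nn

# Assumption: these are the dropout variants worth re-enabling; extend as needed.
DROPOUT_TYPES = (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)


def enable_dropout(model: nn.Module) -> int:
    """Put every dropout-style submodule in training mode; return how many were switched."""
    count = 0
    for mod in model.modules():
        if isinstance(mod, DROPOUT_TYPES):
            mod.train()
            count += 1
    return count


@torch.no_grad()
def predict_with_dropout(model: nn.Module, x: torch.Tensor, n_passes: int = 10) -> torch.Tensor:
    """Hypothetical helper: stack the outputs of n_passes stochastic forward passes."""
    model.eval()           # BatchNorm etc. switch to evaluation behaviour
    enable_dropout(model)  # dropout layers keep sampling masks
    return torch.stack([model(x) for _ in range(n_passes)])


torch.manual_seed(0)
# Toy CIFAR-10-shaped network, used only to exercise the helpers.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout2d(0.3),
    nn.Flatten(),
    nn.Dropout(0.5),
    nn.Linear(8 * 32 * 32, 10),
)

samples = predict_with_dropout(net, torch.randn(2, 3, 32, 32), n_passes=5)
mean_pred = samples.mean(dim=0)  # average over the stochastic passes
spread = samples.std(dim=0)      # per-logit variation across passes
```

Averaging over passes is only one way to use the samples; for the experiment in the question, a single pass with dropout enabled may be all that is needed.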