
PyTorch: simulate dropout when evaluating a model on test data


We are training a ResNet model on the CIFAR-10 dataset and are trying to do the following:

After training the model, we want to simulate dropout during evaluation, when we feed the test data to it. I know this might sound odd because dropout is a regularization mechanism, but we are doing it as part of an experiment.

One option we are considering is to take the state_dict, make a deep copy to preserve the original values, and then modify the values manually.
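A minimal sketch of that state_dict idea, assuming a small hypothetical `nn.Sequential` stand-in for the trained ResNet. Note that zeroing entries of the weight tensors drops *weights* (closer to DropConnect), not activations, so it is not equivalent to dropout. The `1 / (1 - p)` rescaling mirrors inverted dropout and is an assumption; in a real ResNet the `endswith("weight")` filter would also hit BatchNorm scales, which you may want to exclude.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for the trained model
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
model.eval()

x = torch.randn(3, 8)
with torch.no_grad():
    baseline_out = model(x)

# Keep an untouched deep copy of the trained parameters
original_state = copy.deepcopy(model.state_dict())

# Randomly zero a fraction p of every weight tensor (weights, not activations)
p = 0.5
perturbed = {}
for name, tensor in original_state.items():
    if name.endswith("weight"):
        mask = (torch.rand_like(tensor) > p).to(tensor.dtype)
        perturbed[name] = tensor * mask / (1 - p)
    else:
        perturbed[name] = tensor.clone()
model.load_state_dict(perturbed)

with torch.no_grad():
    perturbed_out = model(x)

# Restore the trained parameters from the deep copy
model.load_state_dict(original_state)
with torch.no_grad():
    restored_out = model(x)
```

Each evaluation needs a fresh mask and a restore afterwards, which is part of why this route is clumsier than toggling the dropout modules directly.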

We also noticed that net.eval() switches the dropout layers from training mode to evaluation mode. Is there a way to use this mechanism to simulate dropout during evaluation?
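A small illustration of the mode switch mentioned above, using a standalone `nn.Dropout` (the tensor of ones is just a convenient input). In evaluation mode the module passes its input through unchanged; in training mode it zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.eval()             # evaluation mode: dropout is the identity
eval_out = drop(x)

drop.train()            # training mode: zero elements with probability p
train_out = drop(x)     # and scale the survivors by 1 / (1 - p)

# Every surviving element equals 1 / (1 - 0.5) = 2.0
kept = train_out[train_out != 0]
```

Since the behaviour depends only on each module's `training` flag, flipping that flag back on for the dropout modules is enough to get dropout at test time.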

Is there a better way to achieve what we are trying to do?


Solution

  • The dropout module is disabled in evaluation mode, i.e. after nn.Module.eval is called on your model. If you want it enabled after calling .eval, call nn.Module.train on each of the nn.Dropout modules in your model. The nn.Module.apply method makes this easy, and you can encapsulate it all in an override of your model's train method.

    import torch.nn as nn


    class MyModel(nn.Module):
        ...  # your model's implementation

        # note that self.eval() just calls self.train(False), which is why
        # we are overriding train
        def train(self, mode=True):
            # recursively applies train mode to self and submodules
            super().train(mode)

            # When mode=False (i.e. .eval() was called), re-enable the dropout
            # layers by recursively applying this function to each submodule
            def enable_dropout(mod: nn.Module):
                if isinstance(mod, nn.Dropout):
                    mod.train()

            if not mode:
                self.apply(enable_dropout)

            return self


    model = MyModel()  # construct your model as usual
    

    That way, after calling model.eval(), every module except the dropout layers behaves as it does in evaluation mode (for example, BatchNorm uses its running statistics). Calling model.train() works as before.

    Of course, if you don't want to override the model's train method, you can do the same thing outside the class:

    import torch.nn as nn

    # if you don't want to override the train method, you can do the same
    # thing as the snippet above from outside the class, e.g.:

    model.eval()

    def enable_dropout(mod: nn.Module):
        if isinstance(mod, nn.Dropout):
            mod.train()

    # re-run this after every model.eval() call, since eval() turns dropout off again
    model.apply(enable_dropout)
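Two caveats apply to `enable_dropout`. First, `isinstance(mod, nn.Dropout)` does not match `nn.Dropout2d` or `nn.Dropout3d`, which are separate classes. Second, torchvision's standard ResNets contain no dropout modules at all, so on such a model the function finds nothing to enable. One alternative (a different technique from the answer above, sketched here under assumptions) is to register forward hooks that apply `torch.nn.functional.dropout` with `training=True` to the outputs of chosen layers. The helper name `add_test_time_dropout`, the small stand-in network, and `p=0.5` are all hypothetical; for a torchvision ResNet you might pass something like `[model.layer3, model.layer4]`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def add_test_time_dropout(modules, p=0.2):
    """Hypothetical helper: apply dropout to each module's output, even in eval mode.

    Returns the hook handles so the extra dropout can be removed later.
    """
    def hook(module, inputs, output):
        # training=True draws a fresh random mask on every forward pass
        return F.dropout(output, p=p, training=True)

    return [m.register_forward_hook(hook) for m in modules]


torch.manual_seed(0)
# Hypothetical stand-in for a network without dropout layers (e.g. a ResNet)
model = nn.Sequential(
    nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(),
    nn.Linear(32, 10),
)
model.eval()  # BatchNorm keeps using its running statistics

handles = add_test_time_dropout([model[2]], p=0.5)
x = torch.randn(4, 16)
with torch.no_grad():
    out1, out2 = model(x), model(x)  # different masks, so different outputs
    # Average class probabilities over several stochastic passes
    mean_probs = torch.stack([model(x).softmax(dim=-1) for _ in range(10)]).mean(dim=0)

for h in handles:
    h.remove()  # back to the ordinary deterministic model
with torch.no_grad():
    out3, out4 = model(x), model(x)
```

Averaging several stochastic passes, as with `mean_probs` here, is the usual way such test-time dropout is evaluated (often called Monte Carlo dropout); removing the handles restores the original model without touching its weights.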