Tags: neural-network, deep-learning, reinforcement-learning, pytorch

Is there a way to use an external loss function in PyTorch?


A typical PyTorch neural network has a forward() method. We compute a loss from the outputs of the forward pass and call backward() on that loss to compute the gradients, which the optimizer then uses to update the weights. What if my loss is determined externally (e.g. by running a simulation in some RL environment)? Can I still use this typical structure?

  • This might be a bit dumb, since we no longer know exactly how much each element of the output influences the loss, but perhaps there is some trick I'm not aware of. Otherwise, I'm not sure how neural nets can be combined with other RL algorithms.
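The "trick" asked about here is usually the score-function (log-derivative, or REINFORCE) identity. Even if the reward R is produced by a black-box environment and cannot be differentiated, the gradient of its expected value with respect to the policy parameters θ can be written using only the gradient of the policy's log-probability. The notation below is introduced for illustration: π_θ(a) is a stochastic policy over discrete actions (the state is omitted for brevity), and R(a) is assumed not to depend on θ.

```latex
\nabla_\theta \,\mathbb{E}_{a \sim \pi_\theta}\!\left[R(a)\right]
  = \sum_a R(a)\,\nabla_\theta \pi_\theta(a)
  = \sum_a \pi_\theta(a)\,R(a)\,\nabla_\theta \log \pi_\theta(a)
  = \mathbb{E}_{a \sim \pi_\theta}\!\left[R(a)\,\nabla_\theta \log \pi_\theta(a)\right]
```

In practice this means the quantity you call backward() on can be a surrogate such as -R · log π_θ(a) for a sampled action a, with R treated as a constant. Its gradient is a single-sample estimate of the gradient of -E[R], so no gradient ever has to flow through the environment.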

Thank you!


Solution

  • In this case, it seems easiest to me to separate the forward pass (your policy?) from the loss computation. This is because, as you note, in most scenarios you need to obtain a state from your environment, compute an action (essentially the forward pass), and then feed that action back to the environment to obtain a reward/loss.
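The separation described above can be sketched without any framework. This is a minimal illustration under assumptions: the policy is a table of softmax logits per state, and `external_reward` is a hypothetical stand-in for a simulator that can be queried but not differentiated.

```python
import math
import random


def policy_forward(logits_by_state, state):
    """Forward pass: map a state to a probability distribution over actions (softmax)."""
    logits = logits_by_state[state]
    m = max(logits)                               # subtract max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(probs, rng):
    """Sample an action index from the stochastic policy."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def external_reward(state, action):
    """Hypothetical black-box environment: returns a plain float, no gradient."""
    return 1.0 if action == state % 2 else 0.0


def one_interaction(logits_by_state, state, rng):
    probs = policy_forward(logits_by_state, state)  # 1. forward pass
    action = select_action(probs, rng)              # 2. act
    reward = external_reward(state, action)         # 3. environment returns a number
    # Nothing flows from `reward` back into policy_forward; the learning signal
    # has to be built later from log(probs[action]) weighted by the reward.
    return probs, action, reward
```

For example, `one_interaction({0: [0.0, 0.0], 1: [0.0, 0.0]}, 1, random.Random(0))` returns uniform probabilities, a sampled action, and that action's reward.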

    Of course, you could call your environment inside the forward pass once you have computed an action and calculate the resulting loss there. But there is little benefit in doing so, and it gets even more complicated (though still possible) once you take several steps in your environment before you obtain a reward/loss.
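When several steps pass before a reward arrives, one common way to credit that reward to earlier actions is the discounted return G_t = r_t + γ·G_{t+1}, computed backwards once the episode has ended. The linked REINFORCE example computes its returns in this way before the update. A minimal sketch, with the discount factor `gamma` as an assumed hyperparameter:

```python
def discounted_returns(rewards, gamma=0.99):
    """Return G_t = r_t + gamma * G_{t+1} for every step t of one finished episode."""
    returns = []
    g = 0.0
    for r in reversed(rewards):   # walk backwards from the last step
        g = r + gamma * g
        returns.append(g)
    returns.reverse()             # restore chronological order
    return returns
```

With a single reward at the end, earlier steps receive geometrically smaller credit: `discounted_returns([0.0, 0.0, 1.0], gamma=0.5)` gives `[0.25, 0.5, 1.0]`.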

    I suggest taking a look at the following RL example, which applies policy gradients to an OpenAI Gym environment: https://github.com/pytorch/examples/blob/master/reinforcement_learning/reinforce.py#L43

    The essential ideas are:

    • Create a policy (as an nn.Module) that takes in a state and returns a stochastic policy, i.e. a probability distribution over actions.
    • Wrap the computation of the policy and the sampling of an action from it into one function.
    • Call this function repeatedly to step through your environment, recording actions and rewards.
    • Once an episode is finished, compute the returns from the recorded rewards, and only then perform back-propagation and the gradient update.
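The steps above can be sketched end to end in plain Python, with the gradient written out by hand instead of using autograd. Everything here is an illustrative assumption: a stateless softmax policy over two actions, a hypothetical 3-step episode whose only reward (1.0 if every action was 1) is revealed at the end, and γ = 1, so every step's return equals that final reward. In PyTorch, the policy would be an nn.Module, the sampling would use `torch.distributions.Categorical` with its `log_prob`, and the hand-written update would become a loss of the form `-(log_prob * G).sum()` followed by `backward()` and `optimizer.step()`.

```python
import math
import random

N_STEPS = 3  # hypothetical episode length


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def run_episode(theta, rng):
    """Step through the (hypothetical) environment, recording actions and policy outputs."""
    actions, probs_per_step = [], []
    for _ in range(N_STEPS):
        probs = softmax(theta)                    # policy forward pass
        a = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        actions.append(a)
        probs_per_step.append(probs)
    # External, non-differentiable reward, revealed only when the episode ends.
    reward = 1.0 if all(a == 1 for a in actions) else 0.0
    return actions, probs_per_step, reward


def reinforce_update(theta, actions, probs_per_step, reward, lr):
    """After the episode: theta += lr * G_t * grad log pi(a_t) for every step t.

    With gamma = 1 and one terminal reward, every G_t equals `reward`.
    For softmax logits, d log pi(a) / d theta_k = 1[k == a] - pi(k).
    """
    new_theta = list(theta)
    for a, probs in zip(actions, probs_per_step):
        for k in range(len(theta)):
            grad_log_pi = (1.0 if k == a else 0.0) - probs[k]
            new_theta[k] += lr * reward * grad_log_pi
    return new_theta


def train(episodes=300, lr=0.5, seed=0):
    rng = random.Random(seed)
    theta = [0.0, 0.0]                            # logits for actions 0 and 1
    for _ in range(episodes):
        actions, probs_per_step, reward = run_episode(theta, rng)
        theta = reinforce_update(theta, actions, probs_per_step, reward, lr)
    return theta
```

After `train()`, `softmax(theta)[1]` should be close to 1: the policy learns to always choose action 1 even though the reward was never differentiated. Doing the update only once per episode is what allows the reward to arrive late.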

    While this example is specific to REINFORCE, the same way of structuring your code applies to other RL algorithms. The same repo also contains two other examples.

    Hope this helps.