Search code examples
pythondeep-learningpytorchreinforcement-learning

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)


I am trying to get to grips with Pytorch and I wanted to try to reproduce this code:

https://github.com/andy-psai/MountainCar_ActorCritic/blob/master/RL%20Blog%20FINAL%20MEDIUM%20code%2002_12_19.ipynb

in Pytorch.

I am having a problem in that this error is being returned:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

A similar question said to use zero_grad() again after the optimizer step, but this hasn't resolved the issue.

I've included the entire code below so hopefully it should be reproduceable.

Any advice would be much appreciated.

import gym
import os
import os.path as osp
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

env = gym.envs.make("MountainCarContinuous-v0")

# Value function

class Value(nn.Module):
    def __init__(self, dim_states):
        super(Value, self).__init__()
        self.net = nn.Sequential(
        nn.Linear(dim_states, 400),
        nn.ReLU(),
        nn.Linear(400,400),
        nn.ReLU(),
        nn.Linear(400, 1)
    )
        self.optimizer = optim.Adam(self.parameters(), lr = 1e-3)
        self.criterion = nn.MSELoss()

    def forward(self, state):
        return self.net(torch.from_numpy(state).float())

    def compute_return(self, output, target):
        self.optimizer.zero_grad()
        loss = self.criterion(output, target)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

# Policy network

class Policy(nn.Module):
    def __init__(self, dim_states, env):
        super(Policy, self).__init__()

        self.hidden1 = nn.Linear(dim_states, 40)
        self.hidden2 = nn.Linear(40, 40)
        self.mu = nn.Linear(40, 1)
        self.sigma = nn.Linear(40,1)
        self.env = env

        self.optimizer = optim.Adam(self.parameters(), lr = 2e-5)

    def forward(self, state):
        state = torch.from_numpy(state).float()
        x = F.relu(self.hidden1(state))
        x = F.relu(self.hidden2(x))
        mu = self.mu(x)
        sigma = F.softmax(self.sigma(x), dim=-1)
        action_dist = Normal(mu, sigma)
        action_var = action_dist.rsample()
        action_var = torch.clip(action_var,
                        self.env.action_space.low[0],
                        self.env.action_space.high[0])

        return action_var, action_dist

    def compute_return(self, action, dist, td_error):
        self.optimizer.zero_grad()
        loss_actor = -dist.log_prob(action)*td_error
        loss_actor.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

# Normalise the state space
import sklearn
import sklearn.preprocessing

state_space_samples = np.array(
    [env.observation_space.sample() for x in range(10000)])
scaler = sklearn.preprocessing.StandardScaler()
scaler.fit(state_space_samples)

# Normaliser
def scale_state(state):
    scaled = scaler.transform([state])
    return scaled
##################################

# Parameters

lr_actor = 0.00002
lr_critic = 0.001

actor = Policy(2, env)
critic = Value(2)

# Training loop params
gamma = 0.99
num_episodes = 300

episode_history = []

for episode in range(num_episodes):

    # Receive initial state from E

    state = env.reset()
    reward_total = 0
    steps = 0
    done = False

    while not done:

        action, dist = actor(state)

        # print(np.squeeze(action))
        next_state, reward, done, _ = env.step(
                                        np.array([action.item()]))

        if episode % 50 == 0:
            env.render()

        steps += 1
        reward_total += reward

        # TD Target
        target = reward + gamma * np.squeeze(critic(next_state), axis=0)

        td_error = target - np.squeeze(critic(state), axis=0)

        # Update actor
        actor.compute_return(action, dist, td_error)

        # Update critic
        critic.compute_return(np.squeeze(critic(state), axis=0), target)


    episode_history.append(reward_total)

    print(f"Episode: {episode}, N Steps: {steps}, Cumulative reward {reward_total}")

    if np.mean(episode_history[-100:]) > 90 and len(episode_history) > 101:
        print("Solved")
        print(f"Mean cumulative reward over 100 episodes {np.mean(episode_history[-100:])}")

Solution

  • Problem lies in this snippet. When you create target variable, there is a forward pass through critic which generates a computation graph and critic(next_state) is the leaf node of that graph making target a part of the graph (you can check this by printing target which will show you grad_fn=<AddBackward0>). Finally, when you call critic.compute_return(critic_out, target), a new computation graph is generated and passing target(which is a part of the previous computation graph) causes a Runtime error.

    Solution is to call detach() on critic(next_state), this will free target variable and it will no longer be a part of the computation graph(again check by printing target).

    target = reward + gamma * np.squeeze(critic(next_state).detach(), axis=0)
    td_error = target - np.squeeze(critic(state), axis=0)
    
    # Update actor
    actor.compute_return(action, dist, td_error)
    
    # Update critic
    critic_out = np.squeeze(critic(state), axis=0)
    print(critic_out)
    critic.compute_return(critic_out, target)