Search code examples
pytorchreinforcement-learningq-learningdqn

Pytorch DQN, DDQN using .detach() caused very wield loss (increases exponentially) and do not learn at all


Here is my implementation of DQN and DDQN for CartPole-v0 which I think is correct.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import torch.optim as optim
import random
import os
import time


class NETWORK(torch.nn.Module):
    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int) -> None:

        super(NETWORK, self).__init__()

        self.layer1 = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU()
        )

        self.layer2 = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU()
        )

        self.final = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.final(x)

        return x

class ReplayBuffer(object):
    def __init__(self, capacity=50000):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, s0, a0, r, s1):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = (s0, a0, r, s1)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size=64):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

class DQN(object):
    def __init__(self):
      self.state_dim = 4
      self.action_dim = 2
      self.lr = 0.001
      self.discount_factor = 0.99
      self.epsilon = 1
      self.epsilon_decay = 0.95
      self.num_train = 0
      self.num_train_episodes = 0
      self.batch_size = 64

      self.predict_network = NETWORK(input_dim=4, output_dim=2, hidden_dim=16).double()

      self.memory = ReplayBuffer(capacity=50000)
      self.optimizer = torch.optim.Adam(self.predict_network.parameters(), lr=self.lr)
      self.loss = 0

    def select_action(self, states: np.ndarray) -> int:
      if np.random.uniform(0, 1) < self.epsilon:
        return np.random.choice(self.action_dim)
      else:
        states = torch.from_numpy(states).unsqueeze_(dim=0)
        with torch.no_grad():
          Q_values = self.predict_network(states)
          action = torch.argmax(Q_values).item()
        return action

    def policy(self, states: np.ndarray) -> int:
      states = torch.from_numpy(states).unsqueeze_(dim=0)
      with torch.no_grad():
        Q_values = self.predict_network(states)
        action = torch.argmax(Q_values).item()
      return action

    def train(self, s0, a0, r, s1, sign):
      if sign == 1:
        self.num_train_episodes += 1
        if self.epsilon > 0.01:
          self.epsilon = max(self.epsilon * self.epsilon_decay, 0.01)
        return

      self.num_train += 1
      self.memory.push(s0, a0, r, s1)
      if len(self.memory) < self.batch_size:
        return
      
      batch = self.memory.sample(self.batch_size)
      state_batch = torch.from_numpy(np.stack([b[0] for b in batch]))
      action_batch = torch.from_numpy(np.stack([b[1] for b in batch]))
      reward_batch = torch.from_numpy(np.stack([b[2] for b in batch]))
      next_state_batch = torch.from_numpy(np.stack([b[3] for b in batch]))

      Q_values = self.predict_network(state_batch)[torch.arange(self.batch_size), action_batch]
      
      next_state_Q_values = self.predict_network(next_state_batch).max(dim=1)[0]
      
      Q_targets = self.discount_factor * next_state_Q_values + reward_batch
      
      loss = F.mse_loss(Q_values, Q_targets.detach())
      
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      self.loss = loss.data.item()

class DDQN(object):
    def __init__(self):
      self.state_dim = 4
      self.action_dim = 2
      self.lr = 0.001
      self.discount_factor = 0.9
      self.epsilon = 1
      self.epsilon_decay = 0.95
      self.num_train = 0
      self.num_train_episodes = 0
      self.batch_size = 64

      self.predict_network = NETWORK(input_dim=4, output_dim=2, hidden_dim=16).double()
      self.target_network = NETWORK(input_dim=4, output_dim=2, hidden_dim=16).double()
      self.target_network.load_state_dict(self.predict_network.state_dict())
      self.target_network.eval()

      self.memory = ReplayBuffer(capacity=50000)
      self.optimizer = torch.optim.Adam(self.predict_network.parameters(), lr=self.lr)

      self.loss = 0

    def select_action(self, states: np.ndarray) -> int:
      if np.random.uniform(0, 1) < self.epsilon:
        return np.random.choice(self.action_dim)
      else:
        states = torch.from_numpy(states).unsqueeze_(dim=0)
        with torch.no_grad():
          Q_values = self.predict_network(states)
        action = torch.argmax(Q_values).item()
        return action

    def policy(self, states: np.ndarray) -> int:
      states = torch.from_numpy(states).unsqueeze_(dim=0)
      with torch.no_grad():
        Q_values = self.predict_network(states)
        action = torch.argmax(Q_values).item()
      return action

    def train(self, s0, a0, r, s1, sign):
      if sign == 1:
        self.num_train_episodes += 1
        if self.num_train_episodes % 2 == 0:
          self.target_network.load_state_dict(self.predict_network.state_dict())
          self.target_network.eval()
          
          if self.epsilon > 0.01:
            self.epsilon = max(self.epsilon * self.epsilon_decay, 0.01)
        return

      self.num_train += 1
      self.memory.push(s0, a0, r, s1)
      if len(self.memory) < self.batch_size:
        return
      batch = self.memory.sample(self.batch_size)
      state_batch = torch.from_numpy(np.stack([b[0] for b in batch]))
      action_batch = torch.from_numpy(np.stack([b[1] for b in batch]))
      reward_batch = torch.from_numpy(np.stack([b[2] for b in batch]))
      next_state_batch = torch.from_numpy(np.stack([b[3] for b in batch]))
      
      Q_values = self.predict_network(state_batch)[torch.arange(self.batch_size), action_batch]
      
      next_state_action_batch = torch.argmax(self.predict_network(next_state_batch), dim=1)
      
      next_state_Q_values = self.target_network(next_state_batch)[torch.arange(self.batch_size), next_state_action_batch]
      
      Q_targets = self.discount_factor * next_state_Q_values + reward_batch
      
      loss = F.smooth_l1_loss(Q_values, Q_targets.detach())
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      self.loss = loss.data.item()

I used following to evaluate and train my DQN and DDQN.

def eval_policy(agent, env_name, eval_episodes=10):
    eval_env = gym.make(env_name)
    avg_reward = 0.
    for _ in range(eval_episodes):
        state, done = eval_env.reset(), False
        while not done:
            action = agent.policy(state)
            state, reward, done, _ = eval_env.step(action)
            avg_reward += reward
    avg_reward /= eval_episodes
    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward


env_name = 'CartPole-v0'
env = gym.make(env_name)
    
agent = DQN() # agent = DDQN()

for i in range(1000):
    state, done = env.reset(), False
    episodic_reward = 0
    while not done:
        action = agent.select_action(np.squeeze(state))
        next_state, reward, done, info = env.step(action)
        episodic_reward += reward      
        sign = 1 if done else 0
        agent.train(state, action, reward, next_state, sign)
        state = next_state        
    print(f'episode: {i}, reward: {episodic_reward}')  
    if i % 20 == 0:
        eval_reward = eval_policy(agent, env_name, eval_episodes=50)
        if eval_reward >= 195:
            print("Problem solved in {} episodes".format(i + 1))
            break

The thing is that my DQN networks do not train and the loss grow exponentially using target.detach() in loss calculation. If I do not use .detach(), the DQN object would train but I believe that is not the correct way. For DDQN, my networks always do not train. Can anyone give some advice on where might be wrong?


Solution

  • so one mistake in your implementation is that you never add the end of an episode to your replay buffer. In your train function you return if sign==1 (end of the episode). Remove that return and adjust the target calculation via (1-dones)*... in case you sample a transition of the end of an episode. The reason why the end of the episode is important is that it is the only experience is where the target is not approximated via bootstrapping. Then DQN trains. For reproducibility I used a discount rate of 0.99 and the seed 2020 (for torch, numpy and the gym environment). I achieved a reward of 199.100 after 241 episodes of training.

    Hope that helps, code is very readable btw.