Tags: pytorch, openai-gym

How to stop rendering the CartPole game from OpenAI Gym in PyTorch


My current code is based on PyTorch's example on their website, where they use env.render() to build the next state. That makes the game run very slowly, and I would like it to run much faster without the renderings. Here is the class that uses the render function, and at the bottom is the full code. Any ideas on how I can get this to work?

class CartPoleEnvManager(): #manage cartpole env
    def __init__(self, device):
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped # unwrapped so we can access dynamics that are otherwise hidden
        self.env.reset() #reset to starting state
        self.current_screen = None #screen initialization
        self.done = False #if episode is finished

    def reset(self):
        self.env.reset()
        self.current_screen = None

    def close(self): #close env
        self.env.close()

    def render(self, mode='human'): #render the current state to screen
        return self.env.render(mode)

    def num_actions_available(self): #returns # actions available to agent (2)
        return self.env.action_space.n

    def take_action(self, action):# step returns tuple containing env observation, reward and diagnostic info -- all from taking a certain action
        _, reward, self.done, _ = self.env.step(action.item()) # only reward and done status are of importance
        return torch.tensor([reward], device=self.device)
    # action is a tensor; action.item() gives the plain number that step() expects

    def just_starting(self):
        return self.current_screen is None

    def get_state(self): # return the current state of the env as a processed screen image
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen() #state = processed image of diff of 2 separate screens
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        else:
            s1 = self.current_screen
            s2 = self.get_processed_screen() # grab a fresh processed screen
            self.current_screen = s2
            return s2 - s1 # this represents a single state

    def get_screen_height(self):
        screen = self.get_processed_screen()
        return screen.shape[2]

    def get_screen_width(self):
        screen = self.get_processed_screen()
        return screen.shape[3]

    def get_processed_screen(self):
        screen = self.env.render(mode='rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        screen_height = screen.shape[1]

        # Strip off top and bottom
        top = int(screen_height * 0.4)
        bottom = int(screen_height * 0.8)
        screen = screen[:, top:bottom, :]
        return screen

    def transform_screen_data(self, screen):       
        # Convert to float, rescale, convert to tensor
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)

        # Use torchvision package to compose image transforms
        resize = T.Compose([
            T.ToPILImage()
            ,T.Resize((40,90))
            ,T.ToTensor()
        ])


        return resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW)

And the full code:

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import pennylane as qml



class DQN(nn.Module):
    def __init__(self, img_height, img_width):
        super().__init__()

        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=64)   
        self.fc2 = nn.Linear(in_features=64, out_features=48)
        self.fc3 = nn.Linear(in_features=48, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=2)        



    def forward(self, b):
        b = b.flatten(start_dim=1)
        #t = F.relu(clayer_out)
        b = F.relu(self.fc1(b))
        b = F.relu(self.fc2(b))
        b = F.relu(self.fc3(b))
        b = self.out(b)
        return b



is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython: from IPython import display

Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward')
) # use experiences to train network



class ReplayMemory():
    def __init__(self, capacity): # replay memory has set capacity, only need to
        self.capacity = capacity # initialize capacity for a ReplayMemory object
        self.memory = [] #memory will hold the experiences
        self.push_count = 0 #to show how many experiences we've added to memory

    def push(self, experience): # to add experiences in memory
        if len(self.memory) < self.capacity: # ensuring that the length of memory doesn't exceed set capacity
            self.memory.append(experience) 
        else:
            self.memory[self.push_count % self.capacity] = experience # once memory is full,
            # new experiences overwrite the oldest ones in a circular fashion
        self.push_count += 1 
    def sample(self, batch_size): # returns a random sample of experiences to train on
        return random.sample(self.memory, batch_size) # random.sample(sequence, k) gives k randomly chosen elements of the sequence

    def can_provide_sample(self, batch_size): # can only sample once memory holds at least batch_size experiences -- important at the beginning of training
        return len(self.memory) >= batch_size


class EpsilonGreedyStrategy(): # exploration vs. exploitation
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step): # decay schedule (where this form comes from was not explained in the video)
        return self.end + self.start*(1/(1+self.decay*current_step))
        # alternative: self.end + (self.start - self.end) * math.exp(-1. * current_step * self.decay)


class LearningRate():
    def __init__(self, start, end, decay, current_step):
        self.start = start
        self.end = end
        self.decay = decay
        self.current_step = current_step
    def get_learning_rate(self, current_step):
        self.current_step += 1
        return self.end + self.start*(1/(1+self.decay*current_step))


class lr(): # learning-rate scheduler stub; unused for now, kept for possible future use
    def __init__(self, learning_rate):
        self.current_step = 0
        self.learning_rate = learning_rate # expects a LearningRate instance
    def update_lr(self):
        lrrate = self.learning_rate.get_learning_rate(self.current_step)
        self.current_step += 1

class Agent():
    def __init__(self, strategy, num_actions, device): # when we later create an agent object, need to get strategy from epsilon, num_actions = how many actions from a given state (2 for this game), device is the device in pytorch for tensor calculations CPU or GPU
        self.current_step = 0 # current step number in the environment
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device

    def select_action(self, state, policy_net): #policy_net is the policy trained by DQN
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(self.device) # explore      
        else:
            with torch.no_grad(): #turn off gradient tracking
                return policy_net(state).argmax(dim=1).to(self.device) # exploit



class CartPoleEnvManager(): #manage cartpole env
    def __init__(self, device):
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped # unwrapped so we can access dynamics that are otherwise hidden
        self.env.reset() #reset to starting state
        self.current_screen = None #screen initialization
        self.done = False #if episode is finished

    def reset(self):
        self.env.reset()
        self.current_screen = None

    def close(self): #close env
        self.env.close()

    def render(self, mode='human'): #render the current state to screen
        return self.env.render(mode)

    def num_actions_available(self): #returns # actions available to agent (2)
        return self.env.action_space.n

    def take_action(self, action):# step returns tuple containing env observation, reward and diagnostic info -- all from taking a certain action
        _, reward, self.done, _ = self.env.step(action.item()) # only reward and done status are of importance
        return torch.tensor([reward], device=self.device)
    # action is a tensor; action.item() gives the plain number that step() expects

    def just_starting(self):
        return self.current_screen is None

    def get_state(self): # return the current state of the env as a processed screen image
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen() #state = processed image of diff of 2 separate screens
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        else:
            s1 = self.current_screen
            s2 = self.get_processed_screen() # grab a fresh processed screen
            self.current_screen = s2
            return s2 - s1 # this represents a single state

    def get_screen_height(self):
        screen = self.get_processed_screen()
        return screen.shape[2]

    def get_screen_width(self):
        screen = self.get_processed_screen()
        return screen.shape[3]

    def get_processed_screen(self):
        screen = self.env.render(mode='rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        screen_height = screen.shape[1]

        # Strip off top and bottom
        top = int(screen_height * 0.4)
        bottom = int(screen_height * 0.8)
        screen = screen[:, top:bottom, :]
        return screen

    def transform_screen_data(self, screen):       
        # Convert to float, rescale, convert to tensor
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)

        # Use torchvision package to compose image transforms
        resize = T.Compose([
            T.ToPILImage()
            ,T.Resize((40,90))
            ,T.ToTensor()
        ])


        return resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW)


def plot(values, moving_avg_period):
    plt.figure(2)
    plt.clf()        
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(values)

    moving_avg = get_moving_average(moving_avg_period, values)
    plt.plot(moving_avg)    
    plt.pause(0.001)
    print("Episode", len(values), "\n", \
          moving_avg_period, "episode moving avg:", moving_avg[-1])
    if is_ipython: display.clear_output(wait=True)

def get_moving_average(period, values):
    values = torch.tensor(values, dtype=torch.float)
    if len(values) >= period:
        moving_avg = values.unfold(dimension=0, size=period, step=1) \
            .mean(dim=1).flatten(start_dim=0)
        moving_avg = torch.cat((torch.zeros(period-1), moving_avg))
        return moving_avg.numpy()
    else:
        moving_avg = torch.zeros(len(values))
        return moving_avg.numpy()


def extract_tensors(experiences):
    # Convert batch of Experiences to Experience of batches
    batch = Experience(*zip(*experiences))

    t1 = torch.cat(batch.state)
    t2 = torch.cat(batch.action)
    t3 = torch.cat(batch.reward)
    t4 = torch.cat(batch.next_state)

    return (t1,t2,t3,t4)


class QValues():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def get_current(policy_net, states, actions):
        return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

    @staticmethod        
    def get_next(target_net, next_states):                
        final_state_locations = next_states.flatten(start_dim=1) \
            .max(dim=1)[0].eq(0).type(torch.bool)
        non_final_state_locations = (final_state_locations == False)
        non_final_states = next_states[non_final_state_locations]
        batch_size = next_states.shape[0]
        values = torch.zeros(batch_size).to(QValues.device)
        values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach()
        return values


batch_size = 128
gamma = 0.999
eps_start = 1
eps_end = 0.01
eps_decay = 0.0005
target_update = 10
memory_size = 500000
lr_start = 0.01
lr_end = 0.00001
lr_decay = 0.00009
num_episodes = 1000 # run for more episodes for better results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
em = CartPoleEnvManager(device)
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
agent = Agent(strategy, em.num_actions_available(), device)
memory = ReplayMemory(memory_size)


policy_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
target_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
target_net.load_state_dict(policy_net.state_dict())

target_net.eval() #tells pytorch that target_net is only used for inference, not training
optimizer = optim.Adam(params=policy_net.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9) # create the LR scheduler once; re-creating it every batch would keep resetting the decay

episode_durations = []
for episode in range(num_episodes): #iterate over each episode
    em.reset()
    state = em.get_state()

    for timestep in count():
        action = agent.select_action(state, policy_net)
        reward = em.take_action(action)
        next_state = em.get_state()
        memory.push(Experience(state, action, next_state, reward))
        state = next_state
        if memory.can_provide_sample(batch_size):
            experiences = memory.sample(batch_size)
            states, actions, rewards, next_states = extract_tensors(experiences)

            current_q_values = QValues.get_current(policy_net, states, actions)
            next_q_values = QValues.get_next(target_net, next_states) #will get the max qvalues of the next state, q values of next state are used via next state
            target_q_values = (next_q_values * gamma) + rewards

            loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
            optimizer.zero_grad() # sets the gradients of all weights and biases in policy_net to zero
            loss.backward() # computes the gradient of the loss with respect to all weights and biases in the policy net
            optimizer.step() # updates the weights and biases with the gradients computed from loss.backward()
            scheduler.step()

        if em.done:
            episode_durations.append(timestep)
            plot(episode_durations, 100)
            break
    if episode % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict()) 


em.close()

Solution

  • You're using a "hacked" (or patched, if you will) version of the CartPole environment which, in effect, replaces the real state that CartPole-v0 returns with a rendered image. So your code is trying to train a policy that takes an image as input rather than the 4-value feature array the original CartPole-v0 returns.

    If you look closely, when you call

     _, reward, self.done, _ = self.env.step(action.item())
    

    the first element _ is the actual state of the original CartPole-v0 env.

    Then, instead of using that, the class you have renders the screen and returns an image as the input for training.

    So for the existing task (where the state effectively is an image) you can't really skip rendering, since it is part of preparing the inputs for the policy. If you are open to changing the setup, you could feed the policy the 4-value observation directly instead; a sketch of that follows below.
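
    As a minimal sketch of that alternative (not the code from the tutorial this question follows; the names SimpleEnvManager and SimpleDQN are hypothetical, and it assumes the same old gym API the rest of the code uses), the wrapper below feeds the 4-value observation returned by env.reset() / env.step() straight to the network, so nothing ever gets rendered:

        import gym
        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        class SimpleEnvManager(): # hypothetical no-rendering counterpart of CartPoleEnvManager
            def __init__(self, device):
                self.device = device
                self.env = gym.make('CartPole-v0') # no need for .unwrapped here
                self.done = False
                self.current_state = None

            def reset(self):
                obs = self.env.reset() # 4 values: cart position, cart velocity, pole angle, pole angular velocity
                self.done = False
                self.current_state = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)

            def num_actions_available(self):
                return self.env.action_space.n

            def take_action(self, action):
                obs, reward, self.done, _ = self.env.step(action.item()) # keep the observation instead of discarding it
                self.current_state = torch.from_numpy(obs).float().unsqueeze(0).to(self.device)
                return torch.tensor([reward], device=self.device)

            def get_state(self): # no screen diffing needed; the observation already contains the velocities
                if self.done: # mirror the black-screen trick so QValues.get_next still detects final states
                    return torch.zeros_like(self.current_state)
                return self.current_state

        class SimpleDQN(nn.Module): # same idea as the DQN above, but with 4 input features instead of pixels
            def __init__(self, num_state_features, num_actions):
                super().__init__()
                self.fc1 = nn.Linear(num_state_features, 64)
                self.fc2 = nn.Linear(64, 32)
                self.out = nn.Linear(32, num_actions)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                return self.out(x)

    The rest of the training loop can stay essentially the same (e.g. policy_net = SimpleDQN(4, em.num_actions_available()).to(device)); since get_state() no longer calls render(), nothing is drawn to the screen and episodes run much faster.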