python, python-3.x, pytorch, boolean

torch.set_grad_enabled(False): TypeError: 'bool' object is not callable


I have a linear layer that is raising this error:

TypeError: 'bool' object is not callable

I have tried looking at the file grad_mode.py, with no luck. This is the __enter__ method where the error is raised:

def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

When I start my environment, I can call torch.set_grad_enabled(False) successfully. But when I run the code cell again, I receive the error and find that torch.set_grad_enabled is now a bool.
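
Inspecting the attribute directly shows the same thing. A rough sketch of what I see in the notebook (the exact type name depends on the PyTorch version):

import torch

# Fresh kernel: torch.set_grad_enabled is callable and toggling gradients works
print(callable(torch.set_grad_enabled))   # True
torch.set_grad_enabled(False)             # no error

# After my training cell has run once, the same checks print:
#   callable(torch.set_grad_enabled)  -> False
#   type(torch.set_grad_enabled)      -> <class 'bool'>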

This is my agent:

import torch
import torch.nn as nn
import torch.nn.functional as F
import collections
import random
import math
import numpy as np
import torch.optim as optim
from typing import Type


class CardAgent(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.first_layer = params["first_layer_size"]
        self.second_layer = params["second_layer_size"]
        self.third_layer = params["third_layer_size"]
        self.gamma = params["gamma"]
        self.learning_rate = params["learning_rate"]
        self.memory = collections.deque(maxlen= params["memory_size"])
        self.batch_size = params["batch_size"]
        self.weights_path = params["weights_path"]
        self.optimizer = None
        self.mask = None
        self.network()

    def network(self):
        self.requires_grad_ = False
        self.fc1 = nn.Linear(57, self.first_layer)
        self.fc2 = nn.Linear(self.first_layer, self.second_layer)
        self.fc3 = nn.Linear(self.second_layer, self.third_layer)
        self.fc4 = nn.Linear(self.third_layer, 60)


    def forward(self, observation):
        x = F.relu(self.fc1(observation))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.softmax(self.fc4(x), dim=-1)

        if self.mask is not None:
          print(x)
          return x * self.mask
        return x

    def remember(self, observation, move, reward, next_state, complete):
        self.memory.append((observation, move, reward, next_state, complete))

    def train_memory(self, observation, move, reward, next_state, complete):
        self.train()
        torch.set_grad_enabled = True

        state_tensor = torch.tensor(np.expand_dims(observation, 0), dtype=torch.float32, requires_grad=True)
        next_state_tensor = torch.tensor(np.expand_dims(observation, 0), dtype=torch.float32, requires_grad = True)

        if not complete:
            target = reward + self.gamma * torch.max(self.forward(next_state_tensor[0]))

        output = self.forward(state_tensor)
        target_f = output.clone()
        target_f[0][np.argmax(move)] = target
        target_f.detach()
        self.optimizer.zero_grad()
        loss = F.mse_loss(output, target_f)
        loss.backward()
        self.optimizer.step()

    def replay_exp(self):
        if len(self.memory) > self.batch_size:
            minibatch = random.sample(self.memory, self.batch_size)
        else:
            minibatch = self.memory

        for observation, move, reward, next_state, complete in minibatch:
            self.train_memory(observation, move, reward, next_state, complete)

This is my game loop:

def play(player, agent):
    state = player.observation()
    print(f"\nplayer {player.index+1}")
    if random.uniform(0,1) < agent.epsilon:
        prediction = torch.rand(60)
        prediction = prediction * player.mask()
    else:
        with torch.no_grad():
            state = state.reshape(1, 57)
            agent.mask = player.mask
            prediction = agent(state)
            print(f"agentPred: {prediction}")

    move = np.argmax(prediction).cpu().detach().numpy().item()

    print(f"move: {move}:{to_cs([move])}")

    player.do_move(move)
    print(f"reward: {player.reward}")

    next_state = player.observation()
    m = np.eye(60)[np.argmax(prediction).numpy()]

    agent.remember(observation=state, move=m, reward=player.reward, next_state=next_state, complete=player.game.complete)

def run():
    agent1 = CardAgent(params=params1)
    agent1.optimizer = optim.Adam(
        agent1.parameters(), weight_decay=0, lr=params1['learning_rate'])
    agent2 = CardAgent(params=params2)
    agent2.optimizer = optim.Adam(
        agent2.parameters(), weight_decay=0, lr=params2['learning_rate'])
    games_count = 0
    steps = 0

    def replay(agent):
        agent.replay_exp()
        model_weights = agent.state_dict()
        torch.save(model_weights, agent.weights_path)

    while games_count < params['episodes']:
        if game.complete:
            steps = 0
            initialize_game(game=game, players=[player1, player2])
            print("\nhands")

            print(to_cs(player1.hand))
            print(to_cs(player2.hand))

            print("\n top card")
            print(cs[game.top_card])

        while not game.complete:
            if game.turn == 0:
                if not params1['train']:
                    agent1.epsilon = 0.01
                else:
                    agent1.epsilon = 1 - (games_count * params1["epsilon_decay_linear"])

                play(player=player1, agent=agent1)
            elif game.turn == 1:
                if not params2['train']:
                    agent2.epsilon = 0.01
                else:
                    agent2.epsilon = 1 - \
                        (games_count * params1["epsilon_decay_linear"])
                play(player=player2, agent=agent2)


            print(f"game: {games_count}.  step: {steps} turn: {game.turn} score: {player1.won} - {player2.won}")
            steps += 1
            if steps>1000:
                game.complete = True
            if game.complete:
                games_count += 1
                replay(agent=agent1)
                replay(agent=agent2)


Solution

  • You are facing this error because you are assigning True to torch.set_grad_enabled (instead of calling it) in train_memory:

    torch.set_grad_enabled = True
    

    Now, when you call this:

    torch.set_grad_enabled(False)
    

    it is equivalent to:

    True(False)
    
    TypeError: 'bool' object is not callable
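
    The fix is to call the function rather than assign to it. Below is a minimal sketch; the restore line assumes your PyTorch version also exposes set_grad_enabled under torch.autograd (recent versions do), which lets you repair the clobbered attribute without restarting the kernel:

    import torch

    # In train_memory, enable gradients by calling the function:
    torch.set_grad_enabled(True)       # correct
    # torch.set_grad_enabled = True    # wrong: rebinds the module attribute to a plain bool

    # If the attribute has already been overwritten in a live session,
    # restore it (or simply restart the kernel):
    torch.set_grad_enabled = torch.autograd.set_grad_enabled

    Once the real function is back, the with torch.no_grad(): block in play works again, because torch.no_grad() calls torch.set_grad_enabled(False) in its __enter__, as shown in the grad_mode.py snippet in the question.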