python pytorch openai-gym

Convert Pytorch Float Model into Double


I'm trying to solve CartPole from Gym. It turns out that the states are double-precision floats (float64), whereas PyTorch creates models in single precision (float32) by default.
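One quick way to confirm the mismatch is to check the dtypes directly. This is a minimal sketch that assumes the states arrive as NumPy arrays of Python floats, as the post reports; the observation values are made up for illustration.

```python
import numpy as np
import torch

# Hypothetical CartPole-like observation: NumPy builds float64 arrays from Python floats.
state = np.array([0.01, -0.02, 0.03, 0.04])
print(state.dtype)                  # float64

# torch.from_numpy keeps the NumPy dtype, so this tensor is torch.float64 (torch.double).
state_t = torch.from_numpy(state)
print(state_t.dtype)                # torch.float64

# nn.Linear creates its parameters in the default dtype, which is float32 unless changed.
layer = torch.nn.Linear(4, 2)
print(layer.weight.dtype)           # torch.float32
```

Feeding `state_t` straight into `layer` would therefore fail until one side is cast to match the other.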

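from torch.nn import Module, Linear, ReLU, LeakyReLU

class QNetworkMLP(Module):
    # MLP that maps a state vector to one Q-value per action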
    def __init__(self,state_dim,num_actions):
        super(QNetworkMLP,self).__init__()
        self.l1 = Linear(state_dim,64)
        self.l2 = Linear(64,64)
        self.l3 = Linear(64,128)
        self.l4 = Linear(128,num_actions)
        self.relu = ReLU()
        self.lrelu = LeakyReLU()
    
    def forward(self, x):
        # LeakyReLU after each hidden layer; the last layer outputs raw Q-values
        x = self.lrelu(self.l1(x))
        x = self.lrelu(self.l2(x))
        x = self.lrelu(self.l3(x))
        x = self.l4(x)
        return x

I tried to convert it via

model = QNetworkMLP(4,2).double()

But it still doesn't work; I get the same error:

File ".\agent.py", line 117, in update_online_network
    predicted_Qval = self.online_network(states_batch).gather(1,actions_batch)
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\27abh\Desktop\OpenAI Gym\Cartpole\agent_model.py", line 16, in forward
    x = self.lrelu(self.l1(x))
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\functional.py", line 1674, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat1' in call to _th_addmm
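The last frame of the traceback shows the call `torch.addmm(bias, input, weight.t())`, so argument #2 `'mat1'` is the input batch. Read that way, the message says the layer's parameters are already Double (so `.double()` did take effect) while the input reaching the network is still Float, which suggests `states_batch` ends up as float32 somewhere in `agent.py` (the post does not show where it is built). The sketch below reproduces that situation with a single `nn.Linear` layer and a random float32 batch standing in for `states_batch`. The exact error text differs between PyTorch versions, so it only records that a `RuntimeError` was raised.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2).double()    # parameters are now float64
batch = torch.randn(8, 4)           # float32 by default; hypothetical stand-in for states_batch

raised = False
try:
    layer(batch)                    # float32 input vs float64 weights -> dtype mismatch
except RuntimeError as err:
    raised = True
    print("dtype mismatch:", err)

# Casting the input to the parameters' dtype makes the call succeed.
out = layer(batch.double())
print(out.dtype)                    # torch.float64
```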

Solution

  • Can you try this after initializing your model:

     model.to(torch.double)
    

    Also make sure that all the inputs you pass to the model are of dtype torch.double.
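Putting the answer's two points together might look like the minimal sketch below. It uses a small `nn.Sequential` network as a hypothetical stand-in for `QNetworkMLP` and a random NumPy batch in place of real CartPole states. For an `nn.Module`, `model.to(torch.double)` and `model.double()` both cast the floating-point parameters to float64 in place, so the step that matters is converting the input, here with `torch.as_tensor(..., dtype=torch.double)`. Option B is a different approach, not taken from the answer: keep the model in float32 and cast the states down instead, which keeps the weights at half the memory of float64.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for QNetworkMLP (same idea, fewer layers).
model = nn.Sequential(nn.Linear(4, 64), nn.LeakyReLU(), nn.Linear(64, 2))

# Option A (the answer's suggestion): make the model float64 ...
model.to(torch.double)              # nn.Module.to casts parameters in place and returns self

# ... and make sure every input is float64 as well.
states = np.random.rand(32, 4)      # hypothetical batch of states, float64
states_batch = torch.as_tensor(states, dtype=torch.double)
q_values = model(states_batch)
print(q_values.dtype)               # torch.float64

# Option B (alternative): keep a float32 model and cast the states to float32.
model_f32 = nn.Sequential(nn.Linear(4, 64), nn.LeakyReLU(), nn.Linear(64, 2))
q_values_f32 = model_f32(torch.as_tensor(states, dtype=torch.float32))
print(q_values_f32.dtype)           # torch.float32
```

Whichever option you pick, every place that builds a batch for the network (including the replay-buffer sampling that produces `states_batch`) needs to use the same dtype as the model.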