I'm trying to solve cartpole from Gym. It turns out that the states are in double floating point precision whereas the pytorch by default creates model in single floating point precision.
class QNetworkMLP(Module):
def __init__(self,state_dim,num_actions):
super(QNetworkMLP,self).__init__()
self.l1 = Linear(state_dim,64)
self.l2 = Linear(64,64)
self.l3 = Linear(64,128)
self.l4 = Linear(128,num_actions)
self.relu = ReLU()
self.lrelu = LeakyReLU()
def forward(self,x) :
x = self.lrelu(self.l1(x))
x = self.lrelu(self.l2(x))
x = self.lrelu(self.l3(x))
x = self.l4(x)
return x
I tried to convert it via
model = QNetworkMLP(4,2).double()
But it still doesn't work I get the same error.
File ".\agent.py", line 117, in update_online_network
predicted_Qval = self.online_network(states_batch).gather(1,actions_batch)
File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\27abh\Desktop\OpenAI Gym\Cartpole\agent_model.py", line 16, in forward
x = self.lrelu(self.l1(x))
File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\linear.py", line 91, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Users\27abh\anaconda3\envs\gym\lib\site-packages\torch\nn\functional.py", line 1674, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat1' in call to _th_addmm
Can you try this after initializing your model:
model.to(torch.double)
Also be sure to check if all your inputs to the model are of torch.double data type