AttributeError: 'tuple' object has no attribute 'dim', when feeding input to Pytorch LSTM network

I am trying to run the following code:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out

state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]

device = torch.device("cuda")
net = LSTM(5, 3).to(device)

state_v = torch.FloatTensor(state).to(device)

q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

And that returns this error:

Traceback (most recent call last):
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/", line 26, in <module>
    q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/", line 15, in forward
    out = self.hidden2tag(out)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/", line 1022, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Does anyone know how to fix this, i.e. how to stop a tuple from being passed to the linear layer that follows the LSTM?


  • PyTorch's nn.LSTM returns a tuple: the output tensor, plus a second tuple holding the final hidden and cell states.
    You get this error because your linear layer self.hidden2tag cannot handle a tuple.

    So change:

    out = self.lstm(x)

    to:

    out, states = self.lstm(x)

    This fixes your error by unpacking the tuple, so that out is just the output tensor.

    out then holds the hidden state of the last LSTM layer for every time step, while states is another tuple, (h_n, c_n), containing the final hidden state and the final cell state.

    You can also take a look here:

    You will get another error on the last line, because torch.max() called with dim also returns a tuple (values and indices), and a tuple has no .item(). That should be easy to fix, though, and it is a different error :)
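
Putting both fixes together, a minimal sketch of a working version could look like the one below. It is not the asker's final code and makes a few assumptions that are not in the question: batch_first=True is set so that the (1, 6, 5) input is read as one sequence of six time steps, only the last time step is passed to the linear layer, and the device falls back to the CPU when no GPU is available.

```python
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        # Assumption: batch_first=True, so input is (batch, seq_len, features).
        self.lstm = nn.LSTM(input_shape, 12, batch_first=True)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); unpack it before the linear layer.
        out, (h_n, c_n) = self.lstm(x)
        # Assumption: only the last time step is used to score the actions.
        return self.hidden2tag(out[:, -1, :])


state = [(1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7),
         (4, 5, 6, 7, 8), (5, 6, 7, 8, 9), (6, 7, 8, 9, 0)]

# Fall back to the CPU when no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = LSTM(5, 3).to(device)

state_v = torch.tensor(state, dtype=torch.float32, device=device)
q_vals_v = net(state_v.unsqueeze(0))  # shape (1, 3): one row of action scores

# torch.max with dim= returns (values, indices); unpack before calling .item().
_, action_v = torch.max(q_vals_v, dim=1)
action = int(action_v.item())
```

Because the weights are randomly initialised, the chosen action differs from run to run; the point is only that q_vals_v is now a plain tensor and action is a plain Python int.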