
Pytorch: the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (3)

I am trying to use a prioritized replay buffer for my DQN agent. The problem I encounter is the following.

My world has a (40, 40, 1) state representation. When I try to add a transition to the buffer, I get:

RuntimeError: expand(torch.DoubleTensor{[40, 40, 1]}, size=[3]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (3)

The Prioritized Replay buffer code:

import torch

# SumTree is defined elsewhere (not shown in the question)
class PrioritizedReplayBuffer:
    def __init__(self, state_size=3, action_size=1, buffer_size=10000, eps=1e-2, alpha=0.1, beta=0.1):
        self.tree = SumTree(size=buffer_size)

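        # PER params
        self.eps = eps  # small constant added to priorities so none is zero
        self.alpha = alpha  # how strongly priorities skew sampling (0 = uniform)
        self.beta = beta  # strength of the importance-sampling correction
        self.max_priority = eps  # priority assigned to newly added transitions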

        # transition: state, action, reward, next_state, done
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0
        self.real_size = 0
        self.size = buffer_size

    def add(self, transition):
        state, action, reward, next_state, done = transition

        # store transition index with maximum priority in sum tree
        self.tree.add(self.max_priority, self.count)

        # store transition in the buffer
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # update counters
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

Any help would be appreciated. Thanks.
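A minimal reproduction of where the error comes from (a hypothetical sketch, assuming the state arrives as a float64 tensor, as a NumPy observation would): with `state_size=3`, each row `self.state[self.count]` has shape `(3,)`. Assigning a `(40, 40, 1)` value into that row makes PyTorch try to expand the value to the row's shape, which fails because the value has more dimensions (3) than the target (1). The exact wording of the message may differ between PyTorch versions.

```python
import torch

# Hypothetical minimal reproduction of the question's allocation:
# each row of the buffer is a flat vector of length state_size.
buffer_size = 5
state_size = 3
state_buffer = torch.empty(buffer_size, state_size, dtype=torch.float)
print(state_buffer[0].shape)  # torch.Size([3])

# A state shaped like the environment's observation (float64, like NumPy data)
state = torch.zeros(40, 40, 1, dtype=torch.float64)

try:
    # Indexed assignment must expand `state` to the row shape [3];
    # a 3-D tensor cannot be expanded to a 1-D target, so PyTorch raises.
    state_buffer[0] = state
    failed = False
except RuntimeError as err:
    failed = True
    print(err)
```

The fix is therefore to allocate rows whose shape matches the observation, not to change how the transition is converted.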


  • The problem is solved by allocating the state buffers with the (40, 40, 1) state shape instead of a flat vector of length state_size (3):

    # transition: state, action, reward, next_state, done
    self.state = torch.empty((buffer_size, 40, 40, 1), dtype=torch.float)
    self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
    self.reward = torch.empty(buffer_size, dtype=torch.float)
    self.next_state = torch.empty((buffer_size, 40, 40, 1), dtype=torch.float)
    self.done = torch.empty(buffer_size, dtype=torch.int)
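Hardcoding `40, 40, 1` works for this environment. A slightly more general sketch, assuming the integer `state_size` is replaced by a hypothetical `state_shape` tuple that is unpacked into `torch.empty`. The `ReplayStorage` class below is illustrative only: it keeps the question's storage and counter logic but leaves out the `SumTree` and all priority handling.

```python
import torch


class ReplayStorage:
    """Hypothetical storage-only part of a replay buffer (no SumTree/priorities).

    state_shape is a tuple such as (40, 40, 1); it replaces the integer
    state_size, so each stored row has exactly the observation's shape.
    """

    def __init__(self, state_shape=(40, 40, 1), action_size=1, buffer_size=10000):
        # Unpack the shape so rows are (40, 40, 1) rather than a flat vector
        self.state = torch.empty((buffer_size, *state_shape), dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty((buffer_size, *state_shape), dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.int)

        self.count = 0
        self.real_size = 0
        self.size = buffer_size

    def add(self, transition):
        state, action, reward, next_state, done = transition
        # as_tensor keeps the source dtype (e.g. float64 from NumPy);
        # the indexed assignment casts it to the buffer's dtype.
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # Circular write pointer, as in the question's code
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)


storage = ReplayStorage(state_shape=(40, 40, 1), action_size=1, buffer_size=4)
obs = torch.rand(40, 40, 1, dtype=torch.float64)
storage.add((obs, [2.0], 1.0, obs, False))
print(storage.state.shape)  # torch.Size([4, 40, 40, 1])
```

Because the row shape now comes from a constructor argument, the same class accepts other observation shapes (for example a flat `(3,)` vector) without editing the allocation lines.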