pytorch, reshape

Explain (T,) tensor shape


In the following d2l tutorial:

import torch

T = 1000

time = torch.arange(1, T + 1, dtype=torch.float32)
print(f"time shape: {time.shape}")

x = torch.sin(0.01 * time) + torch.normal(0.0, 0.2, size=(T,))

Given that the shape of torch.sin(0.01 * time) is torch.Size([1000]), why is the size argument passed to torch.normal written as (T,) and not (T)?


Solution

  • Because (T) is just T, an int: parentheses around a single value only group the expression. torch.normal requires a tuple for size, and (T,) with its trailing comma is the Python syntax for a one-element tuple.
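
The difference can be checked in plain Python, without PyTorch. This is a minimal sketch; the variable names with_parens, one_tuple and bare are only illustrative and do not come from the tutorial.

```python
T = 1000

with_parens = (T)   # parentheses only group the expression; still the int 1000
one_tuple = (T,)    # the trailing comma makes this a one-element tuple

print(type(with_parens).__name__)  # int
print(type(one_tuple).__name__)    # tuple
print(len(one_tuple))              # 1

# The comma, not the parentheses, is what creates the tuple.
bare = T,
print(bare == one_tuple)           # True
```

So size=(T) hands torch.normal the plain int 1000, while size=(T,) hands it the tuple (1000,), which matches the shape torch.Size([1000]) of the sine term.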