What values does torch.Tensor.size() return apart from height and width?


I'm using a module (SwinIR) based on PyTorch for image super-resolution.

One of the operations it performs is:

b, c, h, w = img_lq.size()

E = torch.zeros(b, c, h*scale_factor, w*scale_factor).type_as(img_lq)
W = torch.zeros_like(E)

In other words, it reads the shape of the input tensor and creates new zero tensors E and W that keep the first two dimensions but have the height and width multiplied by scale_factor. I'm trying to understand what the two other values returned by size(), assigned to b and c, represent, given that h and w are the height and width, respectively.
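A minimal sketch of what the snippet produces, assuming a hypothetical input img_lq of shape (1, 3, 64, 64) and scale_factor = 4; both values are chosen only for illustration and are not taken from SwinIR itself:

```python
import torch

# Hypothetical low-quality input: 1 image, 3 channels (RGB), 64x64 pixels
img_lq = torch.rand(1, 3, 64, 64)
scale_factor = 4  # assumed upscaling factor, for illustration only

b, c, h, w = img_lq.size()

# Output buffers: same batch and channel sizes, spatial dimensions scaled up
E = torch.zeros(b, c, h * scale_factor, w * scale_factor).type_as(img_lq)
W = torch.zeros_like(E)

print(img_lq.shape)  # torch.Size([1, 3, 64, 64])
print(E.shape)       # torch.Size([1, 3, 256, 256])
print(W.shape)       # torch.Size([1, 3, 256, 256])
```

Only h and w are scaled; b and c are copied unchanged, so E is not the same shape as img_lq unless scale_factor is 1.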

I looked through the PyTorch documentation but couldn't find anything indicating what those two variables are.


Solution

  • torch.Tensor.size() returns the shape of the tensor as a torch.Size object (a subclass of Python's tuple), exactly as accessing the torch.Tensor.shape attribute would.

    In other words, the first line simply unpacks the dimension sizes of img_lq into b ("batch size"), c ("channels"), h ("height"), and w ("width"), which is the (N, C, H, W) layout that PyTorch's 2-D convolution layers expect. Of course, this code assumes that img_lq has exactly four dimensions; otherwise the unpacking raises a ValueError.
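A short sketch of both points, using a hypothetical batch of 8 RGB images of 32x48 pixels (any 4-D tensor would behave the same way):

```python
import torch

# Hypothetical batch in (batch, channels, height, width) order
img_lq = torch.rand(8, 3, 32, 48)

# size() with no argument and the .shape attribute return the same torch.Size
print(img_lq.size())                  # torch.Size([8, 3, 32, 48])
print(img_lq.shape)                   # torch.Size([8, 3, 32, 48])
print(img_lq.size() == img_lq.shape)  # True

# torch.Size is a tuple subclass, so it can be unpacked like any tuple
b, c, h, w = img_lq.size()
print(b, c, h, w)  # 8 3 32 48

# size(dim) returns a single dimension; negative indices count from the end
print(img_lq.size(1), img_lq.size(-1))  # 3 48

# Unpacking fails when the tensor does not have exactly four dimensions
single_image = torch.rand(3, 32, 48)  # no batch dimension
try:
    b, c, h, w = single_image.size()
except ValueError as err:
    print("ValueError:", err)
```

If a single image of shape (C, H, W) has to go through code like this, single_image.unsqueeze(0) adds a leading batch dimension of size 1, giving shape (1, C, H, W).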