Tags: pytorch, tensor

Expand the tensor by several dimensions


In PyTorch, given a tensor (vector) of length n, how can I expand it by several dimensions so that each entry is copied across the added dimensions? For example, given a tensor of size (3), I want to expand it to size (3,2,5,5) so that the added dimensions hold the corresponding values from the original tensor. Concretely, with size (3) and vector [1,2,3], the first (2,5,5) sub-tensor should have all values 1, the second all values 2, and the third all values 3.

In addition, how can I expand a tensor of size (3,2) to (3,2,5,5)?

One way I can think of is to create a tensor of ones of the target size (e.g. with torch.ones_like) and then combine it with the original via einsum, but I think there should be an easier way.
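For reference, the ones-plus-einsum workaround mentioned in the question might look like the sketch below. It assumes float inputs, and the subscript string 'i,ijkl->ijkl' is one possible choice. Note that it allocates a full tensor of ones at the target shape.

```python
import torch

# The vector from the question and the desired target shape.
x = torch.tensor([1.0, 2.0, 3.0])
target = (3, 2, 5, 5)

# Build ones at the target shape, then scale each (2, 5, 5) slice
# by the matching entry of x.
ones = torch.ones(target)
out = torch.einsum('i,ijkl->ijkl', x, ones)

print(out.shape)        # torch.Size([3, 2, 5, 5])
print(out[1].unique())  # tensor([2.])
```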


Solution

  • You can first unsqueeze the appropriate number of singleton dimensions, then expand them to the target shape with torch.Tensor.expand, which returns a view without copying data:

    >>> import torch
    >>> x = torch.rand(3)
    >>> target = [3,2,5,5]
    
    >>> x[:, None, None, None].expand(target)
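
As a quick check against the example in the question, the sketch below uses the concrete vector [1, 2, 3] instead of random values and confirms that each (2, 5, 5) slice holds a single repeated entry:

```python
import torch

x = torch.tensor([1, 2, 3])
target = [3, 2, 5, 5]

# Add three trailing singleton dims, then broadcast them to the target sizes.
out = x[:, None, None, None].expand(target)

# Each (2, 5, 5) slice is filled with the corresponding entry of x.
for i, value in enumerate(x.tolist()):
    assert torch.all(out[i] == value)

print(out.shape)  # torch.Size([3, 2, 5, 5])
```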
    

    A neat alternative is to use torch.Tensor.reshape or torch.Tensor.view to perform several unsqueezes at once:

    >>> x.view(-1, 1, 1, 1).expand(target)
    

    This generalizes to any target shape whose first dimension matches len(x):

    >>> x.view(len(x), *(1,)*(len(target)-1)).expand(target)
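
The one-liner above can be wrapped in a small function. The name `expand_leading` is hypothetical, chosen only for this sketch, and it assumes x is 1-D with target[0] == len(x):

```python
import torch

def expand_leading(x, target):
    """Hypothetical helper: broadcast a 1-D tensor x along new trailing dims.

    Assumes target[0] == len(x); the remaining sizes are arbitrary.
    """
    # View x as (len(x), 1, 1, ...) with one singleton per extra target dim.
    singletons = (1,) * (len(target) - 1)
    return x.view(len(x), *singletons).expand(target)

x = torch.tensor([1.0, 2.0, 3.0])
print(expand_leading(x, [3, 2, 5, 5]).shape)  # torch.Size([3, 2, 5, 5])
print(expand_leading(x, [3, 4]).shape)        # torch.Size([3, 4])
```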
    

    For an even more general version, where x can be multi-dimensional (this covers the (3,2) case from the question):

    >>> x = torch.rand(3, 2)
    
    >>> # check that the leading dimensions of target match x.shape
    >>> assert list(x.shape) == list(target[:x.ndim])
    
    >>> x.view(*x.shape, *(1,)*(len(target)-x.ndim)).expand(target)
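
Here is a sketch of the multi-dimensional case with concrete values for the (3, 2) tensor from the question. It also shows that the result of expand is a view: the added dimensions have stride 0, so many elements share one memory location and in-place writes to it are unsafe. If "copying" in the question means an independent tensor, calling .clone() or using torch.Tensor.repeat materialises one:

```python
import torch

y = torch.tensor([[1, 2], [3, 4], [5, 6]])  # size (3, 2)
target = [3, 2, 5, 5]

# One trailing singleton per missing dim, then broadcast with expand().
out = y.view(*y.shape, *(1,) * (len(target) - y.ndim)).expand(target)
print(out[2, 1].unique())  # tensor([6])

# expand() returns a view: the new dims have stride 0, so no data is copied.
print(out.stride())  # (2, 1, 0, 0)

# Materialised alternatives that own their memory:
rep = y[:, :, None, None].repeat(1, 1, 5, 5)
print(torch.equal(rep, out))  # True

copy = out.clone()
copy[0, 0, 0, 0] = 100  # safe: copy does not alias y
print(y[0, 0].item())   # 1
```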