
Why does torch.Tensor subtraction work when the tensor sizes are different?


An example makes this easier to understand. Both of the following subtractions fail:

import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])    # shape : (2, 3)
B = torch.tensor([[1, 2], [3, 4], [5, 6]])  # shape : (3, 2)
print((A - B).shape)

# RuntimeError: The size of tensor A (3) must match the size of tensor B (2) at non-singleton dimension 1
# ==================================================================
A = torch.tensor([[1, 2], [3, 4], [5, 6]])  # shape : (3, 2)
B = torch.tensor([[1, 2], [3, 4]])          # shape : (2, 2)
print((A - B).shape)

# RuntimeError: The size of tensor A (3) must match the size of tensor B (2) at non-singleton dimension 0

But the following works well:

import torch

a = torch.ones(8).unsqueeze(0).unsqueeze(-1).expand(4, 8, 7)  # shape : ( 4, 8, 7 )
a_temp = a.unsqueeze(2)                                        # shape : ( 4, 8, 1, 7 )
b_temp = torch.transpose(a_temp, 1, 2)                         # shape : ( 4, 1, 8, 7 )
print((a_temp - b_temp).shape)                                 # shape : ( 4, 8, 8, 7 )

Why does the latter work, but not the former?
How/why has the result shape been expanded?


Solution

  • This is explained by PyTorch's broadcasting semantics. The important part is:

    Two tensors are “broadcastable” if the following rules hold:

    • Each tensor has at least one dimension.
    • When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.

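    In your case, (2, 3) and (3, 2) cannot be broadcast to a common shape: starting at the trailing dimension, 3 != 2 and neither is equal to 1. The same happens with (3, 2) and (2, 2), this time at the leading dimension (3 != 2). In contrast, (4, 8, 1, 7) and (4, 1, 8, 7) are broadcast-compatible, because every pair of sizes is either equal or contains a 1, so they broadcast to (4, 8, 8, 7).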

    This is essentially what the error states: every dimension must either be equal ("match") or be a singleton (i.e. equal to 1).

    Conceptually, when the shapes are broadcast, both tensors are expanded so that their shapes match (here to [4, 8, 8, 7]), and the subtraction is then performed element-wise as usual. The expansion does not actually copy your data: like Tensor.expand, it reuses the same elements by giving each stretched size-1 dimension a stride of 0.
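The broadcasting rules and the expansion described above can be checked with a short sketch. It assumes a PyTorch version that provides torch.broadcast_shapes. The function broadcast_shape is a hypothetical helper written only for illustration and is not part of PyTorch. It applies the trailing-dimension rule by hand, compares the result with torch.broadcast_shapes for the shape pairs from the question, and then shows that Tensor.expand returns a stride-0 view rather than a copy.

```python
import torch


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper (not a PyTorch API): broadcast two shapes by hand.

    A missing leading dimension is treated as size 1. Walking from the
    trailing dimension, sizes must be equal or one of them must be 1.
    """
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s ("one of them does not exist").
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = [0] * ndim
    for i in reversed(range(ndim)):  # trailing dimension first
        da, db = a[i], b[i]
        if da == db or db == 1:
            result[i] = da
        elif da == 1:
            result[i] = db
        else:
            raise ValueError(f"size {da} vs {db} at non-singleton dimension {i}")
    return tuple(result)


# The shape pairs from the question: two failing cases and the working one.
cases = [((2, 3), (3, 2)), ((3, 2), (2, 2)), ((4, 8, 1, 7), (4, 1, 8, 7))]
for sa, sb in cases:
    try:
        ours = broadcast_shape(sa, sb)
    except ValueError as err:
        ours = f"error: {err}"
    try:
        theirs = tuple(torch.broadcast_shapes(sa, sb))
    except RuntimeError as err:
        theirs = f"error: {err}"
    print(sa, sb, "->", ours, "| torch:", theirs)

# expand() returns a view: the stretched size-1 dimension gets stride 0, no copy.
x = torch.arange(6.0).reshape(2, 1, 3)  # shape (2, 1, 3)
x_big = x.expand(2, 4, 3)                # shape (2, 4, 3), same storage as x
print(x_big.stride(), x_big.data_ptr() == x.data_ptr())

# The question's broadcast subtraction matches explicit expansion followed by subtraction.
a = torch.ones(8).unsqueeze(0).unsqueeze(-1).expand(4, 8, 7)
a_temp = a.unsqueeze(2)          # (4, 8, 1, 7)
b_temp = a_temp.transpose(1, 2)  # (4, 1, 8, 7)
explicit = a_temp.expand(4, 8, 8, 7) - b_temp.expand(4, 8, 8, 7)
print(torch.equal(explicit, a_temp - b_temp), (a_temp - b_temp).shape)
```

In this sketch, the dimension reported for each failing pair is the same one named in the question's RuntimeError messages (1 for the first pair, 0 for the second), because both walk the shapes from the trailing dimension.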