I have the following two PyTorch tensors A and B.
A = torch.tensor(np.array([40, 42, 38]), dtype = torch.float64)
tensor([40., 42., 38.], dtype=torch.float64)
B = torch.tensor(np.array([[[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5],[1,2,3,4,5]], [[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8],[4,5,6,7,8]], [[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11],[7,8,9,10,11]]]), dtype = torch.float64)
tensor([[[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.],
[ 1., 2., 3., 4., 5.]],
[[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.],
[ 4., 5., 6., 7., 8.]],
[[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.],
[ 7., 8., 9., 10., 11.]]], dtype=torch.float64)
Tensor A is of shape:
torch.Size([3])
Tensor B is of shape:
torch.Size([3, 5, 5])
How do I multiply tensor A with tensor B (using broadcasting) so that, for example, the first value in tensor A (i.e. 40.) is multiplied with all the values in the first 'nested' tensor in tensor B, i.e.
tensor([[1., 2., 3., 4., 5.],
[1., 2., 3., 4., 5.],
[1., 2., 3., 4., 5.],
[1., 2., 3., 4., 5.],
[1., 2., 3., 4., 5.]], dtype=torch.float64)
and so on for the other two values in tensor A and the other two nested tensors in tensor B, respectively.
I could do this multiplication (via broadcasting) with NumPy arrays if A and B were both arrays of shape (3,), i.e. A*B, but I can't figure out the counterpart of this for PyTorch tensors. Any help would really be appreciated.
When applying broadcasting in PyTorch (as well as in NumPy), the shapes are compared dimension by dimension, starting at the last dimension (see https://pytorch.org/docs/stable/notes/broadcasting.html): each pair of sizes must be equal, or one of them must be 1, or one of them must not exist. If they do not match, you need to reshape your tensor. In your case they can't be broadcast directly:
[3]        # the last dimensions (3 and 5) are neither equal nor 1
[3, 5, 5]
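The trailing-dimension rule can be sketched in a few lines of plain Python. The helper `broadcast_shape` below is hypothetical (it is not a PyTorch function); it only mirrors the rule from the linked notes. The second half uses a stand-in B of ones with the question's shape (3, 5, 5) to show that PyTorch rejects A * B for these shapes.

```python
import torch

def broadcast_shape(a, b):
    """Hypothetical helper: return the broadcast shape of shapes a and b,
    or None if they are incompatible. Dimensions are compared from the last one."""
    result = []
    # Walk both shapes right to left; a missing dimension behaves like size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da == db or da == 1 or db == 1:
            result.append(db if da == 1 else da)
        else:
            return None  # sizes differ and neither is 1
    return tuple(reversed(result))

print(broadcast_shape((3,), (3, 5, 5)))       # None: 3 vs 5 in the last dimension
print(broadcast_shape((3, 1, 1), (3, 5, 5)))  # (3, 5, 5)

# PyTorch applies the same rule, so multiplying A (3,) by B (3, 5, 5) fails.
A = torch.tensor([40., 42., 38.], dtype=torch.float64)
B = torch.ones(3, 5, 5, dtype=torch.float64)  # stand-in with the question's shape
try:
    A * B
except RuntimeError as err:
    print("RuntimeError:", err)
```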
Instead, you can redefine A = A[:, None, None] before multiplying, so that you get the shapes
[3, 1, 1]
[3, 5, 5]
which satisfies the conditions for broadcasting.
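A minimal end-to-end sketch, rebuilding A and B from the question and checking that each slice B[i] is scaled by A[i]. Besides indexing with None, it shows `unsqueeze`, `view` and `torch.einsum` as equivalent ways to get the same result; the way B is regenerated with NumPy is just a compact reconstruction of the values printed above.

```python
import numpy as np
import torch

A = torch.tensor(np.array([40, 42, 38]), dtype=torch.float64)
# Rebuild B from the question: slice i is five copies of [1..5] shifted by 3*i.
row = np.arange(1, 6)
B = torch.tensor(np.stack([np.tile(row + 3 * i, (5, 1)) for i in range(3)]),
                 dtype=torch.float64)

# Insert two singleton dims so A has shape (3, 1, 1) and broadcasts against (3, 5, 5).
C = A[:, None, None] * B

# Equivalent ways to get a (3, 1, 1) version of A.
C_unsq = A.unsqueeze(-1).unsqueeze(-1) * B
C_view = A.view(-1, 1, 1) * B
# einsum spells out the index pattern directly: out[i, j, k] = A[i] * B[i, j, k].
C_eins = torch.einsum('i,ijk->ijk', A, B)

print(C.shape)  # torch.Size([3, 5, 5])
print(C[0])     # every row is [40., 80., 120., 160., 200.]
```

Indexing with None (or `unsqueeze`) does not copy data; it only adds size-1 dimensions, and broadcasting then virtually expands them to 5 x 5 during the multiplication.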