
Element-wise multiplication of matrices by scalars


In this example, I want to multiply each of the ten 3x3 matrices in the batch (batch size 10) by its corresponding scalar. Is there a cleaner way to do this than calling unsqueeze twice?

import torch

# Create a batch of scalars (e.g., 10 scalars)
batch_size = 10
scalars = torch.randn(batch_size)  # Shape: (10,)

# Create a batch of 3x3 second-order tensors (e.g., 10 matrices of size 3x3)
tensors = torch.randn(batch_size, 3, 3)  # Shape: (10, 3, 3)

# Do we really have to unsqueeze twice to perform element-wise multiplication?
result = scalars.unsqueeze(1).unsqueeze(2) * tensors  # Shape: (10, 3, 3)
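A minimal sketch of why some reshape is needed at all, reusing the shapes from the question. PyTorch broadcasting compares shapes starting from the last dimension, so a plain (10,) tensor is lined up against the trailing dimension of size 3 and the multiplication fails. The variable name `broadcast_failed` is illustrative only.

```python
import torch

batch_size = 10
scalars = torch.randn(batch_size)          # shape (10,)
tensors = torch.randn(batch_size, 3, 3)    # shape (10, 3, 3)

# Broadcasting aligns shapes from the last dimension backwards:
# (10,) is matched against the trailing 3 of (10, 3, 3), and 10 != 3.
try:
    scalars * tensors
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print("RuntimeError:", err)

# With two trailing singleton axes the shapes are (10, 1, 1) and (10, 3, 3),
# which broadcast to (10, 3, 3): each scalar scales one whole matrix.
expanded = scalars[:, None, None]
print(expanded.shape)
print((expanded * tensors).shape)
```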

Solution

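  • Because broadcasting aligns shapes from the trailing dimension, the scalars need shape (10, 1, 1) so that they line up with the batch dimension of tensors. You can create that broadcastable view with more concise syntax; all of the following are equivalent: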

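    result = scalars.unsqueeze(1).unsqueeze(2) * tensors  # insert two trailing axes
    result = scalars.view(-1, 1, 1) * tensors             # reshape; -1 infers the batch size
    result = scalars[:, None, None] * tensors             # indexing with None adds new axes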
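A quick sketch that checks the three forms agree against an explicit per-matrix loop. It also includes two alternatives that are not part of the answer above: `reshape`, which returns a view when possible and copies otherwise (so, unlike `view`, it also works on non-contiguous inputs), and `torch.einsum`, a swapped-in technique whose subscript string spells out the batch-wise scaling directly. The result variable names are illustrative.

```python
import torch

torch.manual_seed(0)
batch_size = 10
scalars = torch.randn(batch_size)          # shape (10,)
tensors = torch.randn(batch_size, 3, 3)    # shape (10, 3, 3)

# The three forms from the answer: each turns scalars into shape (10, 1, 1).
r_unsqueeze = scalars.unsqueeze(1).unsqueeze(2) * tensors
r_view = scalars.view(-1, 1, 1) * tensors
r_index = scalars[:, None, None] * tensors

# Alternatives not from the original answer:
# reshape returns a view when possible and copies otherwise.
r_reshape = scalars.reshape(-1, 1, 1) * tensors
# einsum: for each batch index b, scale every (i, j) entry by scalars[b].
r_einsum = torch.einsum("b,bij->bij", scalars, tensors)

# Reference: multiply each matrix by its scalar in a Python loop.
r_loop = torch.stack([s * m for s, m in zip(scalars, tensors)])

for r in (r_unsqueeze, r_view, r_index, r_reshape, r_einsum):
    assert r.shape == (batch_size, 3, 3)
    assert torch.allclose(r, r_loop)
print("all forms agree")
```

In practice `scalars[:, None, None]` is the shortest of these, while the einsum string documents the intent most explicitly.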