
How to combine slice assignment, mask assignment and broadcasting in PyTorch?


To be more specific, I'm wondering how to assign to a tensor by slice and by mask at different dimensions simultaneously in PyTorch. Here's a small example of what I want to do, using the tensors and mask below:

import torch

x = torch.zeros(2, 3, 4, 6)
mask = torch.tensor([[True, True, False], [True, False, True]])
y = torch.rand(2, 3, 1, 3)

I want to achieve something like

x[mask, :, :3] = y[mask]
  • In dimensions 0 and 1, only the 4x6 slices of x (and the 1x3 slices of y) whose corresponding element in mask is True should be assigned.
  • In dimension 2, I want the single row of y to be broadcast to all 4 rows of x.
  • In dimension 3, only the first 3 elements in x are assigned with the 3-element tensor from y.
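The three bullets can be pinned down with a slow but explicit loop, which is handy as a reference when testing a vectorized version. This is a minimal sketch assuming the question's shapes (a 2D mask over the first two dimensions); `assign_reference` and its `width` parameter are hypothetical names used only for illustration.

```python
import torch


def assign_reference(x, mask, y, width=3):
    """Hypothetical loop-based reference for the intended assignment.

    For every (a, b) with mask[a, b] True, y[a, b] (shape [1, width]) is
    broadcast over all rows of x[a, b, :, :width]; everything else is untouched.
    Works on a clone, so the input x is not modified.
    """
    out = x.clone()
    for a in range(mask.shape[0]):
        for b in range(mask.shape[1]):
            if mask[a, b]:
                # y[a, b] has shape [1, width]; it broadcasts across dim 2 of x
                out[a, b, :, :width] = y[a, b]
    return out


torch.manual_seed(0)
x = torch.zeros(2, 3, 4, 6)
mask = torch.tensor([[True, True, False], [True, False, True]])
y = torch.rand(2, 3, 1, 3)
ref = assign_reference(x, mask, y)
```

Being loop-based, it scales poorly, but its output is a convenient ground truth to compare any one-line indexing expression against.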

However, the assignment x[mask, :, :3] = y[mask] raises the following error:

RuntimeError: shape mismatch: value tensor of shape [4, 1, 3] cannot be broadcast to indexing result of shape [4, 3, 6]

It seems that PyTorch applied only the [mask] indexing and ignored the :3 slice.

I've also tried

x[mask][:, :, :3] = y[mask]

No error was raised, but x was left unchanged.
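The silent failure of the chained form is expected PyTorch behaviour: indexing with a boolean mask is advanced indexing, which returns a new tensor rather than a view, so the subsequent slice assignment only modifies that temporary. A minimal sketch reproducing this with the question's shapes:

```python
import torch

x = torch.zeros(2, 3, 4, 6)
mask = torch.tensor([[True, True, False], [True, False, True]])
y = torch.rand(2, 3, 1, 3)

# Chained form: x[mask] materializes a copy, and the slice assignment then
# writes into that temporary copy, not into x.
x[mask][:, :, :3] = y[mask]
print(torch.count_nonzero(x).item())  # 0: x is unchanged

# Contrast: a single indexing expression on the left-hand side of `=` is
# performed in place on x itself (4 masked slices of 4x6 = 96 elements).
x[mask] = 1.0
print(torch.count_nonzero(x).item())  # 96
```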

I know I can assign by slice and by mask step by step, but I'd like to avoid intermediate tensors if possible. Tensors in neural networks can be extremely large, so an all-in-one assignment might take less time and less memory.


Solution

  • You can do the following:

    import torch

    x = torch.zeros(2, 3, 4, 6)
    mask = torch.tensor([[True, True, False], [True, False, True]])
    y = torch.rand(2, 3, 1, 3)
    x[..., :3][mask] = y[mask]
    

    For the 2D mask case, this produces the same result as

    i, j = mask.nonzero(as_tuple=True)
    x[i, j, :, :3] = y[i, j]
    

    This method also works when the mask covers more dimensions:

    x = torch.zeros(2, 3, 3, 4, 6)
    y = torch.rand(2, 3, 3, 1, 3)
    mask = torch.rand(2, 3, 3) > 0.5
    x[..., :3][mask] = y[mask]
    

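    For additional dims, the only constraints are that the first n = mask.ndim dimensions of x and y must match the shape of mask, and that the remaining dimensions of y must be broadcastable to those of x[..., :3]. Here the final dimension of y has size 3 to match the :3 slice.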
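A quick way to check that the one-liner and the nonzero form agree, for both the 2D and the 3D mask, is to run them side by side on clones of the same tensors. The helper names `one_liner` and `nonzero_form` are hypothetical, and the tuple-unpacked `nonzero` variant is a generalization of the answer's `i, j` form to masks of any rank, not something stated in the answer itself.

```python
import torch

torch.manual_seed(0)


def one_liner(x, mask, y):
    # Answer's form: basic slice first (a view of x), then a boolean-mask
    # assignment, which PyTorch performs in place on that view (hence on x).
    out = x.clone()
    out[..., :3][mask] = y[mask]
    return out


def nonzero_form(x, mask, y):
    # Explicit integer indices for every masked position, followed by the slice;
    # unpacking a tuple of index tensors works for a mask of any rank.
    out = x.clone()
    idx = mask.nonzero(as_tuple=True)
    out[idx + (Ellipsis, slice(None, 3))] = y[idx]
    return out


# 2D mask case from the question.
x = torch.zeros(2, 3, 4, 6)
mask = torch.tensor([[True, True, False], [True, False, True]])
y = torch.rand(2, 3, 1, 3)
same_2d = torch.equal(one_liner(x, mask, y), nonzero_form(x, mask, y))

# 3D mask case from the answer.
x5 = torch.zeros(2, 3, 3, 4, 6)
y5 = torch.rand(2, 3, 3, 1, 3)
mask5 = torch.rand(2, 3, 3) > 0.5
same_5d = torch.equal(one_liner(x5, mask5, y5), nonzero_form(x5, mask5, y5))

# x[..., :3] is basic slicing, so it shares storage with x instead of copying it.
shares_storage = x[..., :3].data_ptr() == x.data_ptr()

print(same_2d, same_5d, shares_storage)  # True True True
```

The design point is that the slice is applied first as basic indexing, which yields a view, so the boolean-mask assignment that follows lands in x's own storage and no copy of x is made. The right-hand side y[mask] is still materialized as a new tensor.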