
Masking the columns of a PyTorch matrix


I have a matrix of shape (batch_size, N, N) and a masking tensor of shape (batch_size, N).

I want to set entries to -infinity only along the columns (not the rows) of each matrix, according to the given mask: each mask[b, j] decides whether column j of the b-th matrix is filled with -infinity.
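To pin down what "columns, not rows" means, here is a slow but explicit reference version. It is a sketch that assumes a mask value of 1 means "keep" and any other value means "fill with -inf" (the question does not state the convention); the function name `mask_columns_loop` is hypothetical.

```python
import torch

def mask_columns_loop(matrix: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Reference (slow) version, hypothetical helper.

    For each batch b, set the whole column j of matrix[b] to -inf wherever
    mask[b, j] != 1 (assumption: 1 = keep). Rows are never masked as a whole.
    """
    out = matrix.clone()
    batch_size, n, _ = matrix.shape
    for b in range(batch_size):
        for j in range(n):
            if mask[b, j] != 1:
                out[b, :, j] = float("-inf")  # every row i gets -inf at column j
    return out

matrix = torch.zeros(1, 3, 3)
mask = torch.tensor([[1, 0, 1]])
print(mask_columns_loop(matrix, mask))
```

With this mask, column 1 becomes -inf in every row, while columns 0 and 2 keep their original values.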


Solution

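  • Repeat the mask along the row dimension (dim 1), so that every row of each matrix sees the same column mask, then use it as a boolean index and assign -infinity:

    # mask: (batch_size, N) -> (batch_size, N, N), where entry [b, i, j] equals mask[b, j]
    matrix[mask.unsqueeze(1).repeat(1, mask.shape[-1], 1) != 1] = float('-inf')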
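A self-contained sketch of the repeat-based approach, checked against a broadcasting alternative. It assumes, as the `!= 1` comparison suggests, that mask entries equal to 1 mark columns to keep; the helper names `mask_columns` and `mask_columns_broadcast` are hypothetical. `Tensor.masked_fill` accepts any boolean mask that is broadcastable to the tensor's shape, so the `(batch_size, 1, N)` mask does not strictly need to be repeated.

```python
import torch

def mask_columns(matrix: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Set matrix[b, :, j] to -inf wherever mask[b, j] != 1 (assumption: 1 = keep).

    matrix: (batch_size, N, N) float tensor; mask: (batch_size, N).
    """
    n = mask.shape[-1]
    # (batch_size, 1, N) -> (batch_size, N, N): every row gets a copy of the mask,
    # so col_mask[b, i, j] == mask[b, j]
    col_mask = mask.unsqueeze(1).repeat(1, n, 1)
    out = matrix.clone()
    out[col_mask != 1] = float("-inf")
    return out

def mask_columns_broadcast(matrix: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Same result without materializing the repeated mask: masked_fill
    broadcasts the (batch_size, 1, N) boolean mask across all rows."""
    return matrix.masked_fill(mask.unsqueeze(1) != 1, float("-inf"))

torch.manual_seed(0)
matrix = torch.randn(2, 4, 4)
mask = torch.tensor([[1, 1, 0, 1],
                     [0, 1, 1, 1]])
a = mask_columns(matrix, mask)
b = mask_columns_broadcast(matrix, mask)
print(torch.equal(a, b))  # True: both fill the same columns with -inf
```

The broadcasting version avoids allocating the (batch_size, N, N) boolean mask. `masked_fill` returns a new tensor, while `masked_fill_` modifies `matrix` in place. In either case `matrix` needs a floating-point dtype to hold -inf.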