I'm given a PyTorch boolean tensor of size NxK (K>=2), which is guaranteed to have at least 2 True values in each row.
For each row, I want to get the indices of 2 random cells with True values.
For example, let's say I have the following tensor:
tensor([[False, True, False, False, True],
[ True, True, True, False, True],
[ True, True, False, True, False],
[False, True, False, False, True],
[False, True, False, False, True]])
So a possible output would be:
tensor([[1, 4],
[4, 2],
[0, 3],
[4, 1],
[4, 1]])
Because in the first row we have True values in columns 1 and 4, in the second row we have True values in columns 4 and 2, and so on.
Another possible output would be:
tensor([[4, 1],
[1, 4],
[0, 1],
[1, 4],
[1, 4]])
Because now different random True values were selected for each row.
I've currently implemented a pytorch-numpy-pytorch solution for that, using np.random.choice:
import numpy as np
import torch

available_trues = [t.nonzero(as_tuple=False).flatten() for t in input_tensor]
only_2_trues = [np.random.choice(t.cpu(), size=(2,), replace=False) for t in available_trues]
only_2_trues = torch.from_numpy(np.stack(only_2_trues)).cuda()
But this solution is not vectorized (since it works using lists) and it requires moving the data back and forth between the CPU and the GPU. Since I'm working on large matrices and running this operation many times, this causes a major slowdown.
I wonder if it can be done without list comprehension or without moving the data (or even without both :] )
Initially I was wondering whether this could be done with torch.topk, but then realized that it would be deterministic.
I was then able to make it work using torch.multinomial, with some extra setup of a probability matrix.
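Assume the matrix is m. Dividing each row by its number of True values gives a uniform probability over the True cells and zero everywhere else:
p = m / m.sum(axis=1).unsqueeze(1)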
tensor([[0.0000, 0.5000, 0.0000, 0.0000, 0.5000],
[0.2500, 0.2500, 0.2500, 0.0000, 0.2500],
[0.3333, 0.3333, 0.0000, 0.3333, 0.0000],
[0.0000, 0.5000, 0.0000, 0.0000, 0.5000],
[0.0000, 0.5000, 0.0000, 0.0000, 0.5000]])
Then
p.multinomial(num_samples=2)
tensor([[4, 1],
[2, 4],
[1, 0],
[1, 4],
[1, 4]])
Each time you run it, you get different results.
Obviously you can combine the above steps into one; I'm just showing exactly what the p matrix is doing.
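As a rough illustration of combining the steps, here is a minimal sketch. The function name sample_true_indices and the variable mask are hypothetical names chosen for this example, not part of the original code. It relies on two documented properties of torch.multinomial: the rows of the input are treated as weights and do not need to sum to one, and sampling is without replacement by default, so the two indices in a row are distinct. Everything stays on the device of the input tensor, so no CPU round trip is involved.

```python
import torch


def sample_true_indices(mask: torch.Tensor, num_samples: int = 2) -> torch.Tensor:
    """Sample `num_samples` distinct column indices of True cells per row.

    Hypothetical helper: assumes `mask` is a 2D boolean tensor with at
    least `num_samples` True values in every row.
    """
    # True -> weight 1.0, False -> weight 0.0; multinomial treats each row
    # as unnormalized weights, so explicit normalization is not required.
    weights = mask.float()
    # replacement=False keeps the sampled indices within a row distinct.
    return torch.multinomial(weights, num_samples=num_samples, replacement=False)


mask = torch.tensor([[False, True, False, False, True],
                     [True, True, True, False, True],
                     [True, True, False, True, False],
                     [False, True, False, False, True],
                     [False, True, False, False, True]])

idx = sample_true_indices(mask)
# Look up the sampled cells; all of them should be True.
picked = mask.gather(1, idx)
print(idx)
print(picked.all())
```

If the input already lives on the GPU, the result is produced there as well, which avoids the list comprehension and the transfers from the question.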