
PyTorch: compare three tensors?


I have three boolean mask tensors, and I want to build a new boolean mask that is 1 wherever the value is the same in all three tensors and 0 otherwise.

I tried torch.where(A == B == C, 1, 0), but this chained comparison doesn't seem to work on tensors.
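The failure comes from Python, not from torch.where. Python evaluates a chained comparison such as A == B == C as (A == B) and (B == C), and the `and` needs a single truth value from the tensor A == B. The minimal sketch below reproduces this with made-up example masks (the values of A, B and C are assumptions for illustration only).

```python
import torch

# Hypothetical example masks; any same-shaped bool tensors behave the same way.
A = torch.tensor([True, False, True])
B = torch.tensor([True, False, False])
C = torch.tensor([True, False, True])

# Python rewrites `A == B == C` as `(A == B) and (B == C)`.
# The `and` calls bool() on the 3-element tensor `A == B`,
# which PyTorch rejects with a RuntimeError (the value is ambiguous).
try:
    torch.where(A == B == C, 1, 0)
    chained_failed = False
except RuntimeError as err:
    chained_failed = True
    print(err)
```

The exception is raised while Python evaluates the argument, before torch.where is ever called.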


Solution

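  • Python expands A == B == C into (A == B) and (B == C), and the `and` calls bool() on the tensor A == B, which raises a RuntimeError for tensors with more than one element. The == operator (torch.eq) compares only two tensors at a time, so perform two comparisons and combine them element-wise with &: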

    (A==B) & (B==C)
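A complete, runnable sketch of this answer follows, using hypothetical example masks. Since the question asks for 1/0 values rather than True/False, it also converts the boolean result to integers with .long(). Note that positions where all three masks are False also count as a match.

```python
import torch

# Hypothetical example masks (same shape, dtype=torch.bool).
A = torch.tensor([True, False, True])
B = torch.tensor([True, False, False])
C = torch.tensor([True, False, True])

# Two binary comparisons, combined element-wise with `&`:
# if A == B and B == C at a position, all three agree there.
all_equal = (A == B) & (B == C)

# Convert the bool mask to 1/0 integers (int64).
ones_zeros = all_equal.long()
print(ones_zeros)  # tensor([1, 1, 0])
```

In recent PyTorch versions, torch.where(all_equal, 1, 0) should give the same 1/0 tensor; .long() is used here because it works regardless of whether torch.where accepts Python scalars in your installed version.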