
How to check if two Torch tensors or matrices are equal?


I need a Torch command that checks whether two tensors have the same content and returns true if they do.

For example:

local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

What should I use in this script instead of EQUIVALENCE_COMMAND?

I tried simply using ==, but it does not work.


Solution

  • torch.eq(a, b)
    

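    eq() implements the == operator element-wise: it compares each element of a with b (if b is a number) or with the corresponding element of b (if b is a tensor). The result is a ByteTensor holding 1 where the elements are equal and 0 where they differ, not a single boolean. Because every value other than nil and false is truthy in Lua, `if torch.eq(a, b) then` would always take the branch, so the result has to be reduced first.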


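    To reduce the element-wise result to a single boolean, @deltheil suggests:

    torch.all(tens_a:eq(tens_b))

    Note the colon: tens_a:eq(tens_b) is the method form of torch.eq(tens_a, tens_b). Writing tens_a.eq(tens_b) passes only tens_b to eq, so the call fails.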
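The compare-then-reduce idea can be seen without Torch installed. Below is a minimal plain-Lua sketch over flat, 1-D Lua tables; `eq`, `all` and `sameContent` are hypothetical helpers that only mimic the behaviour of torch.eq and torch.all described above, they are not Torch functions. The length check in `sameContent` is an assumption added here so that inputs of different sizes simply return false.

```lua
-- Plain-Lua sketch (no Torch required) of the "element-wise eq, then all" pattern.
-- eq, all and sameContent are hypothetical helpers working on flat Lua tables;
-- they imitate torch.eq / torch.all but are not part of Torch.

-- Element-wise comparison returning a mask of 1s and 0s, like torch.eq.
-- b may be a number (compared with every element) or a table of the same length.
local function eq(a, b)
  local mask = {}
  for i = 1, #a do
    local other
    if type(b) == "table" then
      other = b[i]
    else
      other = b
    end
    if a[i] == other then
      mask[i] = 1
    else
      mask[i] = 0
    end
  end
  return mask
end

-- Reduction: true only if every entry of the mask is non-zero, like torch.all.
local function all(mask)
  for i = 1, #mask do
    if mask[i] == 0 then
      return false
    end
  end
  return true
end

-- Full content check: lengths must match before comparing element by element.
local function sameContent(a, b)
  if #a ~= #b then
    return false
  end
  return all(eq(a, b))
end

-- Usage with the values from the question.
local tens_a = {9, 8, 7, 6}
local tens_b = {9, 8, 7, 6}
local mask = eq(tens_a, tens_b)  -- a table of 1s; being a table, it is truthy in an if
print(sameContent(tens_a, tens_b))        -- prints true
print(sameContent(tens_a, {9, 8, 7, 5}))  -- prints false
```

With real Torch7 tensors, the same size guard could be expressed as `tens_a:isSameSizeAs(tens_b) and torch.all(tens_a:eq(tens_b))`, assuming Torch7's `Tensor:isSameSizeAs` method; this variant is an untested suggestion rather than part of the accepted answer.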