
Understanding Array comparison while using custom data types with numpy


I have a question about comparisons on NumPy arrays that hold custom data types.

Here is my code:

import numpy as np


class Expr:

    def __add__(self, other):
        return Add(self, other)

    def __eq__(self, other):
        return Eq(self, other)


class Variable(Expr):

    def __init__(self, name):

        self.name = name

    def __repr__(self):
        return self.name


class Operator(Expr):

    def __init__(self, left, right):

        self.left = left
        self.right = right

    def __repr__(self):
        return f'{self.__class__.__name__}({self.left}, {self.right})'


class Add(Operator):
    ...


class Eq(Operator):
    ...


if __name__ == '__main__':

    arr1 = np.array([Variable('v1'), Variable('v2')])
    arr2 = np.array([Variable('v3'), Variable('v4')])

    print(arr1 + arr2)

    print(arr1 == arr2)

The output is:

[Add(v1, v3) Add(v2, v4)]
[ True  True]

I don't understand why the equality comparison doesn't return [Eq(v1, v3) Eq(v2, v4)] when the same approach works for addition. How do I make this work?

Thank you!


Solution

  • By default, == on arrays produces an output of boolean dtype, even when both inputs have object dtype. The result array physically cannot hold Eq instances; it can only hold booleans. Each Eq instance returned by your __eq__ is therefore converted to a boolean, producing True, since that is the boolean value of any object whose class doesn't define __len__ or __bool__.

    If you want an array of Eq instances, you need an output of object dtype. You can get that by passing dtype=object to numpy.equal (np.equal, given your import):

    res = np.equal(arr1, arr2, dtype=object)
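A minimal end-to-end sketch, reusing the classes from the question (Add and Eq use `pass` instead of `...`, which is equivalent). It shows the default comparison producing a bool array, why each Eq counts as True, and the `dtype=object` fix. The `np.frompyfunc` variant at the end is a different technique from the answer, included only as an assumed alternative; it is worth checking both against your NumPy version.

```python
import operator

import numpy as np


class Expr:
    def __add__(self, other):
        return Add(self, other)

    def __eq__(self, other):
        return Eq(self, other)


class Variable(Expr):
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class Operator(Expr):
    def __init__(self, left, right):
        self.left = left
        self.right = right

    def __repr__(self):
        return f'{self.__class__.__name__}({self.left}, {self.right})'


class Add(Operator):
    pass


class Eq(Operator):
    pass


arr1 = np.array([Variable('v1'), Variable('v2')])
arr2 = np.array([Variable('v3'), Variable('v4')])

# Default comparison: the output dtype is bool, so each Eq returned
# by Expr.__eq__ is converted to a boolean before being stored.
default = arr1 == arr2
print(default.dtype, default)                   # bool [ True  True]

# Any Eq instance is truthy: no class in the hierarchy defines
# __bool__ or __len__.
print(bool(Eq(Variable('a'), Variable('b'))))   # True

# Requesting an object-dtype output keeps the Eq instances unchanged.
res = np.equal(arr1, arr2, dtype=object)
print(res.dtype, res)                           # object [Eq(v1, v3) Eq(v2, v4)]

# Alternative (a different technique): np.frompyfunc wraps a Python
# function as a ufunc whose outputs are always object arrays.
eq_obj = np.frompyfunc(operator.eq, 2, 1)
print(eq_obj(arr1, arr2))                       # [Eq(v1, v3) Eq(v2, v4)]
```

The dtype=object route stays within the ufunc machinery the question already relies on, so broadcasting works the same way it does for +.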