Tags: python, class, equality, slots

Equality of Python classes using slots


Another question provides a nice, simple solution for implementing a test for equality of objects. I'll repeat the answer for context:

class CommonEqualityMixin(object):

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
            and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)

class Foo(CommonEqualityMixin):

    def __init__(self, item):
        self.item = item

I would like to do this for a class that uses __slots__. I understand that both the base class and the subclass will have to use slots, but how would you define __eq__ for this to work with slots?
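To see why the mixin above needs changing, here is a small sketch, written for Python 3, in which Base and Point are hypothetical names. Once every class in the hierarchy declares __slots__, instances have no __dict__, so the __dict__-based comparison raises AttributeError instead of comparing anything.

```python
class Base(object):
    __slots__ = ()  # empty slots: Base adds no per-instance storage

    def __eq__(self, other):
        # The __dict__-based test from the question
        return (isinstance(other, self.__class__)
                and self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)


class Point(Base):
    __slots__ = ('x', 'y')

    def __init__(self, x, y):
        self.x = x
        self.y = y


p = Point(1, 2)
print(hasattr(p, '__dict__'))  # False: slotted instances have no __dict__

try:
    p == Point(1, 2)
except AttributeError as exc:
    # self.__dict__ does not exist, so the comparison itself fails
    print('comparison failed:', exc)
```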


Solution

    import operator

    class CommonEqualityMixin(object):

        __slots__ = ()

        def __eq__(self, other):
            if isinstance(other, self.__class__):
                # self.__slots__ is the attribute of the most-derived class that
                # declares one, so slots inherited from base classes are not included
                if self.__slots__ == other.__slots__:
                    attr_getters = [operator.attrgetter(attr) for attr in self.__slots__]
                    # Equal only if every listed slot holds equal values
                    return all(getter(self) == getter(other) for getter in attr_getters)

            return False

        def __ne__(self, other):
            return not self.__eq__(other)
    

    An example of usage:

    class Foo(CommonEqualityMixin):
        __slots__ = ('a', )
        def __init__(self, a):
            self.a = a
    
    Foo(1) == Foo(2)
    # False
    Foo(1) == Foo(1)
    # True
    

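    N.B.: be aware that __slots__ is not accumulated across the class hierarchy the way instance attributes collect in __dict__. self.__slots__ returns only the names declared by the most-derived class, so if, for example, a new class FooBar inherits from Foo, the code above compares only FooBar's own slots and ignores the inherited slot a.

    Example: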

    class FooBar(Foo):
        __slots__ = ('z',)
        def __init__(self, a, z):
            self.z = z
            super(FooBar, self).__init__(a)
    
    FooBar(1, 1) == FooBar(2, 1)
    # True (a differs, but only z is compared)
    
    print(FooBar(1, 1).__slots__)
    # ('z',)
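One possible way around the inheritance caveat, sketched for Python 3, is to collect slot names from every class in type(self).__mro__ instead of reading self.__slots__. This is an assumption-laden variant rather than the answer's own code: the helper _all_slots and the _MISSING sentinel are hypothetical names, a bare string such as __slots__ = 'z' is treated as a single slot, the special entries __dict__ and __weakref__ are skipped, and slot names are assumed not to be double-underscore private names (Python mangles those, so getattr with the plain name would not find them). The sentinel also keeps a never-assigned slot from raising AttributeError, which the operator.attrgetter version would do.

```python
_MISSING = object()  # hypothetical sentinel for slots that were never assigned


class CommonEqualityMixin(object):
    __slots__ = ()

    def _all_slots(self):
        """Collect slot names declared anywhere in the class hierarchy."""
        names = []
        for cls in type(self).__mro__:
            slots = cls.__dict__.get('__slots__', ())
            if isinstance(slots, str):  # __slots__ = 'z' declares one slot
                slots = (slots,)
            for name in slots:
                # Skip special entries; comparing weakrefs could recurse into __eq__
                if name in ('__dict__', '__weakref__') or name in names:
                    continue
                names.append(name)
        return names

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        # Unset slots compare equal only to other unset slots
        return all(getattr(self, name, _MISSING) == getattr(other, name, _MISSING)
                   for name in self._all_slots())

    def __ne__(self, other):
        return not self.__eq__(other)


class Foo(CommonEqualityMixin):
    __slots__ = ('a',)

    def __init__(self, a):
        self.a = a


class FooBar(Foo):
    __slots__ = ('z',)

    def __init__(self, a, z):
        self.z = z
        super(FooBar, self).__init__(a)


print(FooBar(1, 1) == FooBar(2, 1))  # False: the inherited slot a is compared too
print(FooBar(1, 1) == FooBar(1, 1))  # True
```

Walking the MRO costs a little on every comparison; if that matters, the collected names could be cached per class.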