python contains equals-operator

Python: Strange behaviour of __eq__ and __contains__ on a custom class


I have the following custom classes (stripped down some), implementing an expression tree:

from abc import ABC, abstractmethod


class Operator:
    def __init__(self, str_reps):
        self.str_reps = str_reps

    def __str__(self):
        return self.str_reps[0]

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        return self is not other

    def __hash__(self):
        return hash(str(self))


NOT = Operator(["¬", "~", "not", "!"])
AND = Operator(["∧", "&", "and"])
OR = Operator(["∨", "|", "or"])


class Node(ABC):
    @abstractmethod
    def __eq__(self, other):
        pass

    @abstractmethod
    def __hash__(self):
        pass

    def __ne__(self, other):
        return not self == other

    @abstractmethod
    def __str__(self):
        pass

    @abstractmethod
    def __invert__(self):
        pass

    def bracket_if_necessary(self):
        return str(self)
    

class Leaf(Node):
    def __init__(self, v):
        self.val = v

    def __eq__(self, other):
        if not isinstance(other, Leaf):
            return False
        return self.val == other.val

    def __hash__(self):
        return hash(self.val)

    def __str__(self):
        return str(self.val)

    def __invert__(self):
        return UnaryNode(self)


class UnaryNode(Node):
    def __init__(self, child):
        self.child = child
        self.hash = hash(NOT) + hash(self.child)

    def __eq__(self, other):
        if not isinstance(other, UnaryNode):
            return False
        return self.child == other.child

    def __hash__(self):
        return self.hash

    def __str__(self):
        return str(NOT) + self.child.bracket_if_necessary()

    def __invert__(self):
        return self.child


class VariadicNode(Node):
    def __init__(self, op, children):
        self.op = op
        self.children = children
        self.hash = hash(self.op) + sum(hash(child) for child in self.children)

    def __eq__(self, other):
        if not isinstance(other, VariadicNode):
            return False
        return self.op is other.op and set(self.children) == set(other.children)

    def __hash__(self):
        return self.hash

    def __str__(self):
        return (" " + str(self.op) + " ").join(child.bracket_if_necessary() for child in self)

    def __invert__(self):
        return VariadicNode(AND if self.op is OR else OR, tuple(~c for c in self))

    def bracket_if_necessary(self):
        return "(" + str(self) + ")"

    def __iter__(self):
        return iter(self.children)

    def __contains__(self, item):
        return item in self.children

If I run this and try things like

Leaf("36") == Leaf("36")
~Leaf("36") == ~Leaf("36")
~~Leaf("36") == Leaf("36")

they all return True, as expected.

However, I'm running into bugs in the code that utilizes these nodes:

# Simplify procedure in DPLL Algorithm
def _simplify(cnf, l):
    # print for debugging
    for c in cnf:
        print("b", l, ~l, type(l), "B", c, type(c), l in c)

    return VariadicNode(AND, tuple(_filter(c, ~l) for c in cnf if l not in c))


# Removes the chosen unit literal (negated above) from clause c
def _filter(c, l):
    # print for debugging
    for x in c:
        print("a", l, type(l), "A", x, type(x), x==l)

    return VariadicNode(c.op, tuple(x for x in c if x != l))

Here cnf is given as a VariadicNode(AND) with all children being VariadicNode(OR). Children of VariadicNode are always given as a tuple.

These two prints result in lines like:

a ¬25 <class 'operators.UnaryNode'> A ¬25 <class 'operators.UnaryNode'> False
b ¬25 25 <class 'operators.UnaryNode'> B ¬25 ∨ ¬36 <class 'operators.VariadicNode'> False

which should not happen (¬25 == ¬25 in the first line and ¬25 in (¬25 ∨ ¬36) in the second should both return True). However there is also a line in the output:

b ¬25 25 <class 'operators.UnaryNode'> B ¬25 <class 'operators.VariadicNode'> True

so the check ¬25 in (¬25) actually does return True as it should.

Can anyone tell me what's going on?

If more info is needed, the rest of the code is available at [deleted] (hopefully it's publicly available, I'm pretty new to github so I don't know their policies). Note that it is still a WIP though.

The classes are located in operators.py and the rest of the (relevant) code is in SAT_solver.py, while test.py allows for easy running of the entire project, provided the networkx library is installed.

EDIT

I've now pushed a sample DIMACS .txt file to the GitHub repository that reproduces the described problem. Download hamilton.txt, SAT_solver.py and operators.py from the repository into the same folder, run SAT_solver.py (which has a main() method), and enter hamilton or hamilton.txt when prompted for the problem file name. Leave the solution file name empty when prompted, to prevent the program from writing any files. This should produce a lot of output, including problematic lines like the ones above.


Solution

  • The code you posted is not the code in your GitHub repo. Your posted code defines UnaryNode.__eq__ as

    def __eq__(self, other):
        if not isinstance(other, UnaryNode):
            return False
        return self.child == other.child
    

    the repo code has

    def __eq__(self, other):
        if not isinstance(other, UnaryNode):
            return False
        return self.op is other.op and self.child == other.child
    

    which also requires that the operators are identical. Instrumenting your code and breaking at a failure shows that you're generating two different NOT operators somewhere:

    >>> str(x)
    '¬24'
    >>> str(l)
    '¬24'
    >>> x == l
    False
    >>> x.child == l.child
    True
    >>> str(x.op)
    '¬'
    >>> str(l.op)
    '¬'
    >>> x.op == l.op
    False
    >>> id(x.op)
    2975964300
    >>> id(l.op)
    2976527276
    

    Track down where that's happening and fix it however you like, either by making sure there is never more than one instance of each operator or by not caring whether there is more than one (my preference). I know in a deleted comment you wrote "The operators are single objects and I want them to be compared for equality based on their references", but (1) they're not single objects, and (2) if you didn't want such an unnecessary thing you wouldn't have wound up in trouble.

    If I had to guess, the other operators are being introduced when you call copy.deepcopy, but you're definitely not working with singletons.
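
To illustrate the guess above, here is a minimal sketch of how copy.deepcopy can produce a second operator that prints the same but fails an identity-based __eq__, along with the two kinds of fix mentioned earlier. Operator is a trimmed copy of the class from the question; SingletonOperator and ValueOperator are hypothetical names introduced only for this example and do not appear in the repo. Whether deepcopy is really the source of the duplicates in the actual project is still an assumption.

```python
import copy


class Operator:
    """Trimmed copy of the question's class: equality is object identity."""

    def __init__(self, str_reps):
        self.str_reps = str_reps

    def __str__(self):
        return self.str_reps[0]

    def __eq__(self, other):
        return self is other

    def __hash__(self):
        return hash(str(self))


class SingletonOperator(Operator):
    """Hypothetical fix 1: copying always hands back the same object."""

    def __copy__(self):
        return self

    def __deepcopy__(self, memo):
        return self


class ValueOperator(Operator):
    """Hypothetical fix 2: equality by symbols, so duplicates are harmless."""

    def __eq__(self, other):
        if not isinstance(other, ValueOperator):
            return NotImplemented
        return self.str_reps == other.str_reps

    # Overriding __eq__ disables the inherited __hash__, so restate it.
    def __hash__(self):
        return hash(str(self))


# The failure mode: the copy looks identical but is a different object.
NOT = Operator(["¬", "~", "not", "!"])
NOT_COPY = copy.deepcopy(NOT)
print(str(NOT) == str(NOT_COPY))  # True
print(NOT == NOT_COPY)            # False

# Fix 1: deepcopy of a container keeps pointing at the original operator.
SAFE_NOT = SingletonOperator(["¬", "~", "not", "!"])
print(copy.deepcopy({"op": SAFE_NOT})["op"] is SAFE_NOT)  # True

# Fix 2: a distinct copy still compares (and hashes) equal.
VALUE_NOT = ValueOperator(["¬", "~", "not", "!"])
print(copy.deepcopy(VALUE_NOT) == VALUE_NOT)  # True
```

With either variant, a check such as self.op is other.op (fix 1) or self.op == other.op (fix 2) inside a node's __eq__ keeps working after the tree has been deep-copied.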