python-3.x, algorithm, 2-satisfiability

Implementing an efficient 2-SAT solving algorithm


I was reading about the 2-SAT problem on Wikipedia and I was wondering what the O(n) algorithm looks like in Python.

So far I've only found implementations that are either in other programming languages or only determine whether an expression has a solution, without giving the solution itself.

How could the O(n) algorithm for finding the values of variables be written in Python?


Solution

  • Here is an OOP implementation in Python:

    import re
    
    class TwoSat:
        class Variable:
            def __init__(self, name, negated=None):
                self.name = name
                self.negated = negated or TwoSat.Variable("~" + name, self)
                self.implies = set()
                self.impliedby = set()
                self.component = -1
                
            def disjunction(self, b):
                self.negated.implication(b)
                b.negated.implication(self)
        
            def implication(self, b):
                self.implies.add(b)
                b.impliedby.add(self)
        
            def postorder(self, visited):
                if self not in visited:
                    visited.add(self)
                    for neighbor in self.implies:
                        yield from neighbor.postorder(visited)                    
                    yield self
        
            def setcomponent(self, component):
                if self.component == -1:
                    self.component = component
                    for neighbor in self.impliedby:
                        neighbor.setcomponent(component)
        
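            def value(self):
                # Component numbers increase along the topological order of the
                # implication graph: a literal is true if its component comes after
                # its negation's; equal components mean there is no solution (None).
                diff = self.component - self.negated.component
                return diff > 0 if diff else None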
                
            ### end of class Variable 
                
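        def __init__(self, s=""):
            self.variables = {}
            # Tokenize into literals first, then pair them up in order;
            # zip() drops an unpaired last literal.
            literals = re.findall(r"(~?)(\w+)", s)
            for (a_neg, a_name), (b_neg, b_name) in zip(literals[0::2], literals[1::2]):
                self.getvariable(a_neg, a_name).disjunction(self.getvariable(b_neg, b_name))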
    
        def getvariable(self, neg, name):
            if name not in self.variables:
                self.variables[name] = TwoSat.Variable(name)
                self.variables["~" + name] = self.variables[name].negated
            a = self.variables[name]
            return a.negated if neg else a
    
        def postorder(self):
            visited = set()
            for startvariable in self.variables.values():
                yield from startvariable.postorder(visited)
        
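        def setcomponents(self):
            # Kosaraju's second pass: visit variables in reverse postorder and give
            # every not-yet-labeled variable that reaches them the same component.
            for i, variable in enumerate(reversed(list(self.postorder()))):
                variable.setcomponent(i)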
    
        def issolved(self):
            return all(variable.value() is not None for variable in self.variables.values())
        
        def solve(self):
            self.setcomponents()
            return self.issolved()
    
        def truevariables(self):
            if self.issolved():
                return [variable.name for variable in self.variables.values() if variable.value()]
        
        def __repr__(self):
            return " ∧ ".join(
                f"({a.name} → {b.name})"
                for a in self.variables.values()
                for b in a.implies
            )
    

    Here is an example of how this class can be used:

    problem = TwoSat("(~a+~b)*(b+~c)*(c+g)*(d+a)*(~f+i)*(~i+~j)*(~h+d)*(~d+~b)*(~f+c)*(h+~i)*(i+~g)")
    print(problem)
    problem.solve()
    print("solution: ", problem.truevariables())
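
For a satisfiable problem, truevariables() returns the names of all true literals (a negated name such as ~b appears when b is false), so the result can be checked against the clauses. A small sketch; satisfies is a hypothetical helper, and the clause pairs are assumed to be written out as tuples of literal strings:

```python
def satisfies(clauses, true_literals):
    """Check that a set of true literals is consistent and satisfies every clause.

    clauses: iterable of (literal, literal) pairs such as ("~a", "b").
    true_literals: the literals that are true, e.g. ["a", "~b", ...].
    """
    true_set = set(true_literals)
    # A variable must not be both true and false.
    consistent = not any(("~" + lit) in true_set for lit in true_set if not lit.startswith("~"))
    # Each clause needs at least one true literal.
    return consistent and all(a in true_set or b in true_set for a, b in clauses)


clauses = [("~a", "~b"), ("b", "~c")]
print(satisfies(clauses, ["~a", "b", "~c"]))  # True
print(satisfies(clauses, ["a", "b", "~c"]))   # False
```

With the class above, the call would be satisfies(clauses, problem.truevariables()) after problem.solve() has returned True.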
    

    The TwoSat constructor takes one argument, a string, which should describe a conjunction of two-literal disjunctions (clauses). The syntax rules for this string are:

    • Literals consist of alphanumeric characters (underscores allowed) naming a variable, optionally prefixed with ~ to denote negation.
    • All other characters are treated as separators and are not validated.
    • Literals are taken in pairs, and each consecutive pair forms one disjunction clause.
    • If the number of literals is odd, the expression is not a valid 2-SAT expression; the last literal is then silently ignored.
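
These rules amount to "tokenize, then pair up". A minimal standalone sketch of just the parsing step, assuming the same literal syntax (parse_clauses is a hypothetical name):

```python
import re


def parse_clauses(s):
    """Split a 2-SAT string into (literal, literal) clause pairs.

    Every run of word characters, optionally prefixed by '~', is a literal;
    everything else is a separator. Literals are paired in order, and zip()
    silently drops an unpaired last literal.
    """
    literals = re.findall(r"~?\w+", s)
    return list(zip(literals[0::2], literals[1::2]))


print(parse_clauses("(~a+~b)*(b+~c)"))  # [('~a', '~b'), ('b', '~c')]
```

Because whole literals are matched before pairing, a trailing multi-character literal such as cd in "a b cd" is dropped as one unit rather than being split into c and d.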

    So the example above could equally be written with this string, which represents the same problem:

    problem = TwoSat("~a ~b b ~c c g d a ~f i ~i ~j ~h d ~d ~b ~f c h ~i i ~g")
    

    Alternatively, you can build the expression with the getvariable and disjunction methods; see how the __init__ method uses them when parsing the string. For example:

    problem = TwoSat("")
    for variable in "abcdefghij":
        problem.getvariable(False, variable)
    # Define the disjunction ~a + ~b:
    problem.variables["a"].negated.disjunction(problem.variables["b"].negated)
    # ...etc
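
Each disjunction(b) call on a literal a adds the two implications ¬a → b and ¬b → a to the implication graph. The same construction on plain strings, as a standalone sketch (negate and implication_graph are hypothetical names, not part of the class):

```python
from collections import defaultdict


def negate(literal):
    """Hypothetical helper: '~x' <-> 'x'."""
    return literal[1:] if literal.startswith("~") else "~" + literal


def implication_graph(clauses):
    """Map each literal to the set of literals it implies.

    (a ∨ b) is equivalent to (¬a → b) ∧ (¬b → a), so every clause adds two
    edges, just like Variable.disjunction() calls implication() twice.
    """
    graph = defaultdict(set)
    for a, b in clauses:
        graph[negate(a)].add(b)
        graph[negate(b)].add(a)
    return dict(graph)


print(implication_graph([("~a", "~b"), ("b", "~c")]))
# {'a': {'~b'}, 'b': {'~a'}, '~b': {'~c'}, 'c': {'b'}}
```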
    

    The algorithm is the one explained in the 2-satisfiability article on Wikipedia: it identifies the strongly connected components of the implication graph using Kosaraju's algorithm.
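
Both postorder and setcomponent above are recursive, so very long implication chains can run into Python's recursion limit (1000 by default). The following is a minimal iterative sketch of the same Kosaraju-based approach, assuming clauses are given as pairs of literal strings like ("~a", "b"); solve_2sat and negate are hypothetical names, not part of the class above:

```python
from collections import defaultdict


def negate(literal):
    """Hypothetical helper: '~x' <-> 'x'."""
    return literal[1:] if literal.startswith("~") else "~" + literal


def solve_2sat(clauses):
    """Return {variable: bool} satisfying every clause, or None if unsatisfiable."""
    graph = defaultdict(list)    # implication edges
    reverse = defaultdict(list)  # the same edges, reversed
    nodes = set()
    for a, b in clauses:
        # (a ∨ b) is equivalent to (¬a → b) ∧ (¬b → a)
        for src, dst in ((negate(a), b), (negate(b), a)):
            graph[src].append(dst)
            reverse[dst].append(src)
        nodes.update((a, b, negate(a), negate(b)))

    # Pass 1: iterative DFS on the implication graph, recording postorder.
    order, visited = [], set()
    for start in nodes:
        if start in visited:
            continue
        visited.add(start)
        stack = [(start, iter(graph[start]))]
        while stack:
            node, neighbors = stack[-1]
            for nxt in neighbors:
                if nxt not in visited:
                    visited.add(nxt)
                    stack.append((nxt, iter(graph[nxt])))
                    break
            else:  # all neighbors explored: node is finished
                stack.pop()
                order.append(node)

    # Pass 2: DFS on the reversed graph in reverse postorder. Components are
    # numbered 0, 1, 2, ... in topological order of the implication graph.
    component = {}
    count = 0
    for start in reversed(order):
        if start in component:
            continue
        component[start] = count
        stack = [start]
        while stack:
            node = stack.pop()
            for prev in reverse[node]:
                if prev not in component:
                    component[prev] = count
                    stack.append(prev)
        count += 1

    assignment = {}
    for literal in nodes:
        if literal.startswith("~"):
            continue
        if component[literal] == component[negate(literal)]:
            return None  # x and ~x imply each other: unsatisfiable
        # x is true when its component comes after ~x's in topological order.
        assignment[literal] = component[literal] > component[negate(literal)]
    return assignment
```

For example, solve_2sat([("a", "b"), ("~a", "b"), ("a", "~b")]) returns an assignment with a and b both true, and adding the clause ("~a", "~b") makes it return None. Explicit stacks replace the recursion, so the depth of the implication graph no longer matters; the overall work stays linear in the number of clauses.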