How to perform approximate structural pattern matching for floats and complex numbers


I've read about and understand floating-point round-off issues such as:

>>> sum([0.1] * 10) == 1.0
False

>>> 1.1 + 2.2 == 3.3
False

>>> from math import radians, sin, sqrt
>>> sin(radians(45)) == sqrt(2) / 2
False

I also know how to work around these issues with math.isclose() and cmath.isclose().
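Applied to the examples above, those functions look like this. This is a minimal sketch; per the Python documentation, `math.isclose()` defaults to `rel_tol=1e-09` and `abs_tol=0.0`, and `cmath.isclose()` accepts the same keyword arguments for complex values:

```python
import cmath
import math

# math.isclose() tests abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol),
# with rel_tol=1e-09 and abs_tol=0.0 by default.
print(math.isclose(sum([0.1] * 10), 1.0))                          # True
print(math.isclose(1.1 + 2.2, 3.3))                                # True
print(math.isclose(math.sin(math.radians(45)), math.sqrt(2) / 2))  # True

# cmath.isclose() applies the same test to complex numbers.
print(cmath.isclose(1.1j + 2.2j, 3.3j))                            # True
```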

The question is how to apply those workarounds to Python's match/case statement. I would like this to work:

match 1.1 + 2.2:
    case 3.3:
        print('hit!')  # currently, this doesn't match

Solution

  • The key to the solution is to build a wrapper class that overrides the __eq__ method so that equality becomes an approximate comparison:

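    import cmath
    
    class Approximately(complex):
        """A number whose == test uses cmath.isclose() instead of exact equality."""
    
        def __new__(cls, x, /, **kwargs):
            # Store the value as a complex number so both floats and complex values work.
            result = complex.__new__(cls, x)
            # Remember any rel_tol/abs_tol keyword arguments for later comparisons.
            result.kwargs = kwargs
            return result
    
        def __eq__(self, other):
            try:
                return cmath.isclose(self, other, **self.kwargs)
            except TypeError:
                # Not a number: let Python fall back to its default handling.
                return NotImplemented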
    

    This gives approximate equality tests for both float and complex values:

    >>> Approximately(1.1 + 2.2) == 3.3
    True
    >>> Approximately(1.1 + 2.2, abs_tol=0.2) == 3.4
    True
    >>> Approximately(1.1j + 2.2j) == 0.0 + 3.3j
    True
    

    Here is how to use it in a match/case statement:

    from math import radians, sin
    
    for x in [sum([0.1] * 10), 1.1 + 2.2, sin(radians(45))]:
        match Approximately(x):
            case 1.0:
                print(x, 'sums to about 1.0')
            case 3.3:
                print(x, 'sums to about 3.3')
            case 0.7071067811865476:  # sqrt(2) / 2; patterns can't contain calls
                print(x, 'is close to sqrt(2) / 2')
            case _:
                print('Mismatch')
    

    This outputs:

    0.9999999999999999 sums to about 1.0
    3.3000000000000003 sums to about 3.3
    0.7071067811865475 is close to sqrt(2) / 2