Tags: python, numpy, floating-point, floating-point-comparison

Checking if a specific float value is in list/array in Python/numpy


Checking floating point numbers for equality requires care: the comparison should usually allow for a tolerance, using e.g. numpy.allclose.

Question 1: Is it safe to check for the occurrence of a specific floating point number using the "in" keyword (or are there similar keywords/functions for this purpose)? Example:

if myFloatNumber in myListOfFloats:
  print('Found it!')
else:
  print('Sorry, no luck.')

Question 2: If not, what would be a neat and tidy solution?


Solution

  • If your floats are not computed in the same place or with exactly the same expression, this code can give false negatives because of rounding errors. For example:

    >>> 0.1 + 0.2 in [0.6/2, 0.3]  # We may want this to be True
    False
    

    In this case, we can write a custom "in" function that makes this true. Since one value is compared against each element, numpy.isclose (which works element-wise) is a better and possibly faster fit than numpy.allclose, which is only true when every element matches:

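    import numpy as np

    def close_to_any(a, floats, **kwargs):
      # True if `a` is within tolerance of at least one element of `floats`.
      # Extra keyword arguments (rtol, atol, equal_nan) go to np.isclose.
      return np.any(np.isclose(a, floats, **kwargs))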
    

    The numpy.isclose documentation carries an important warning:

    Warning: The default atol is not appropriate for comparing numbers that are much smaller than one (see Notes). [...] if the expected values are significantly smaller than one, it can result in false positives.

    The note adds that, unlike abs_tol in math.isclose, atol is not zero by default. If you need a custom tolerance when using close_to_any, pass rtol and/or atol as keyword arguments; they are forwarded to numpy. In the end, your existing code would translate to this:

    if close_to_any(myFloatNumber, myListOfFloats):
      print('Found it!')
    else:
      print('Sorry, no luck.')
    

    You can also pass options, e.g. close_to_any(myFloatNumber, myListOfFloats, atol=1e-12). Note that 1e-12 is arbitrary; don't use this value unless you have a good reason to.
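To make the warning about the default atol concrete, here is a minimal sketch. It repeats the close_to_any helper so it runs on its own (wrapped in bool() only to get a plain Python bool). The values in tiny_values are made up for illustration, and atol=0.0 is just one possible choice, not a recommendation.

```python
import numpy as np

def close_to_any(a, floats, **kwargs):
    # True if `a` is within tolerance of at least one element of `floats`.
    return bool(np.any(np.isclose(a, floats, **kwargs)))

# Illustrative values, all much smaller than one.
tiny_values = [1e-10, 2e-10]

# np.isclose tests |a - b| <= atol + rtol * |b|. With the default
# atol=1e-8, every difference between numbers this small passes.
default_result = close_to_any(5e-10, tiny_values)

# With atol=0.0 only the relative tolerance (default rtol=1e-5) remains.
strict_result = close_to_any(5e-10, tiny_values, atol=0.0)

print(default_result)  # True: a false positive
print(strict_result)   # False
```

Which tolerance is appropriate depends on the scale of your data, so treat the numbers above as placeholders.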

    Coming back to the rounding error we observed in the first example, this would give:

    >>> close_to_any(0.1 + 0.2, [0.6/2, 0.3])
    True