
Validation against NumPy dtypes -- what's the least circuitous way to check values?


I want to test an unknown value against the constraints that a given NumPy dtype implies -- e.g., if I have an integer value, is it small enough to fit in a uint8?

As best I can ascertain, NumPy's dtype architecture doesn't offer a way to do something like this:

### FICTIONAL NUMPY CODE: I made this up ###
try:
    numpy.uint8.validate(rupees)
except numpy.dtype.ValidationError:
    print("Users can't hold more than 255 rupees.")

My little fantasy API is based on Django's model-field validators, but that's just one example -- the best mechanism I managed to contrive was along the lines of this:

>>> nd = numpy.array([0,0,0,0,0,0], dtype=numpy.dtype('uint8'))
>>> nd[0]
0
>>> nd[0] = 1
>>> nd[0] = -1
>>> nd
array([255,   0,   0,   0,   0,   0], dtype=uint8)
>>> nd[0] = 257
>>> nd
array([1, 0, 0, 0, 0, 0], dtype=uint8)

Round-tripping the questionable values through a numpy.ndarray explicitly typed as numpy.uint8 gives me back integers that have been wrapped to fit (-1 becomes 255, 257 becomes 1) -- without throwing an exception or raising any other sort of actionable error state.

I'd rather not put on the architecture-astronaut flight suit, of course, but that's preferable to the alternative, which looks like an unmaintainable spaghetti-monster mess of if dtype(this) ... elif dtype(that) statements. Is there anything I can do here besides embarking on the grandiose and indulgent act of writing my own API?


Solution

  • If a is your original iterable, you could do something along the following lines:

    np.all(np.array(a, dtype=np.int8) == a)
    

    Quite simply, this compares the resulting ndarray to the original values, and tells you whether the conversion to ndarray has been lossless.

    This will also catch things like using a floating-point type that's too narrow to represent some of the values exactly:

    >>> a = [0, 0, 0, 0, 0, 0.123456789]
    >>> np.all(np.array(a, dtype=np.float32) == a)
    False
    >>> np.all(np.array(a, dtype=np.float64) == a)
    True
    

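    Edit: One caveat when using the above code with floating-point numbers is that NaN never compares equal to anything, including itself, so any input containing a NaN will fail the check even when the conversion was lossless. If required, it is trivial to extend the code to handle that case too.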
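
For what it's worth, the round-trip comparison and the NaN caveat can be folded into one small helper. This is only a sketch under a couple of assumptions: converts_losslessly is a hypothetical name (nothing like it ships with NumPy), and a conversion that NumPy refuses outright is treated as a failed validation. That second point matters because recent NumPy releases raise OverflowError for out-of-range Python integers instead of silently wrapping them.

```python
import numpy as np


def converts_losslessly(values, dtype):
    """Return True if every item in `values` survives conversion to `dtype`.

    Hypothetical helper, not part of NumPy: it round-trips the values
    through an ndarray of the target dtype and compares with the input.
    """
    original = np.asarray(values)
    try:
        converted = np.array(values, dtype=dtype)
    except (OverflowError, ValueError):
        # NumPy refused the conversion (e.g. an out-of-range Python int on
        # newer releases, or NaN into an integer dtype): count it as lossy.
        return False

    equal = converted == original
    if original.dtype.kind == "f" and converted.dtype.kind == "f":
        # NaN != NaN, so treat positions that are NaN on both sides as equal.
        equal = equal | (np.isnan(original) & np.isnan(converted))
    return bool(np.all(equal))


print(converts_losslessly([0, 1, 255], np.uint8))            # True
print(converts_losslessly([0, 257], np.uint8))               # False
print(converts_losslessly([0.0, 0.123456789], np.float32))   # False
print(converts_losslessly([1.0, float("nan")], np.float64))  # True
```

The function returns False whether NumPy wraps the out-of-range value or rejects it, so the caller gets the same answer across versions and can decide how to report the problem.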