Define a custom float8 in python-numpy and convert from/to float16?


I am trying to define a custom 8 bit floating point format as follows:

  • 1 sign bit
  • 2 bits for mantissa
  • 5 bits for exponent
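For reference, here is a minimal sketch of what a single byte in this format would decode to. It assumes the bits are ordered sign | exponent | mantissa (the order of a float16's high byte) and an IEEE-754-style exponent bias of 15, as in float16. The question states neither, and the helper name `decode_float8` is hypothetical.

```python
def decode_float8(byte):
    """Decode one byte laid out as sign(1) | exponent(5) | mantissa(2).

    Assumption: IEEE-754-style encoding with exponent bias 15 (same as float16).
    """
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 2) & 0x1F  # 5 exponent bits
    mantissa = byte & 0x3          # 2 mantissa bits
    if exponent == 0x1F:           # all-ones exponent: infinity or NaN
        return sign * float("inf") if mantissa == 0 else float("nan")
    if exponent == 0:              # subnormal: no implicit leading 1
        return sign * (mantissa / 4) * 2.0 ** (1 - 15)
    return sign * (1 + mantissa / 4) * 2.0 ** (exponent - 15)
```

Under these assumptions, `decode_float8(0x46)` is 6.0 and the largest finite value, `0x7B`, is 57344.0.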

Is it possible to define this as a numpy datatype? If not, what is the easiest way to convert a numpy array of dtype float16 to such a format (for storage) and convert it back (for calculations in float16), maybe using the bit operations of numpy?

Why:

I am trying to optimize a neural network on custom hardware (FPGA). For this, I am playing around with various float representations. I have already built a forward-pass framework for my neural network with numpy, so something like the above would help me check the reduction in accuracy caused by storing the values in my custom datatype.


Solution

  • I'm by no means an expert in numpy, but I like to think about FP representation problems. The size of your array is not huge, so any reasonably efficient method should be fine. NumPy doesn't seem to have an 8-bit floating-point dtype, I guess because the precision isn't very good.

    To convert a one-dimensional float16 array to bytes, each byte holding a single 8-bit FP value, all you need is:

    import numpy as np

    float16 = np.array([6.3, 2.557, 2.557, 6.3], dtype=np.float16)  # some data as float16
    float8s = float16.tobytes()[1::2]  # keep the high-order byte of each value (little-endian)
    print(float8s)
    # b'FAAF'
    

    This just takes the high-order byte of each 16-bit float by lopping off the low-order byte, giving a 1-bit sign, 5-bit exponent and 2-bit significand. On a little-endian machine the high-order byte is always the second byte of each pair. I've tried it on a 2D array and it works the same. This truncates (rounds toward zero); proper round-to-nearest would be a whole other can of worms.
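    If you'd rather not depend on byte order, the same truncation can be written with NumPy's bitwise operations on a `uint16` view of the float16 data, and the view also makes round-to-nearest-even fairly short. This is a minimal sketch: the function names are hypothetical, and the rounding uses the common "add 0x7F plus the last kept bit, then shift" trick, which is a swapped-in technique rather than part of the byte-slicing approach.

```python
import numpy as np

def to_float8_truncate(x):
    """Keep the high byte of each float16: sign, 5-bit exponent, top 2 mantissa bits."""
    bits = np.asarray(x, dtype=np.float16).view(np.uint16)
    return (bits >> 8).astype(np.uint8)

def to_float8_round(x):
    """Round each float16 to the 8-bit format with round-to-nearest-even."""
    h = np.asarray(x, dtype=np.float16)
    bits = h.view(np.uint16).astype(np.uint32)  # widen so the addition cannot overflow
    lsb = (bits >> 8) & 1                         # last kept bit: ties go to the even code
    rounded = (bits + 0x7F + lsb) >> 8            # values past the max finite code become inf
    # Rounding could turn a NaN into infinity, so map NaNs to a quiet-NaN code explicitly.
    nan_code = ((bits >> 8) & 0x80) | 0x7E
    return np.where(np.isnan(h), nan_code, rounded).astype(np.uint8)
```

    For example, 6.9 (stored as float16 `0x46E6`) truncates to code `0x46` (6.0) but rounds to `0x47` (7.0), while an exact tie such as 6.5 stays at the even code `0x46`.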

    Getting back to 16 bits is just a matter of inserting zeros. I found this method by experiment and there is undoubtedly a better way: it reads the byte string as 8-bit unsigned integers, copies them into an array of big-endian 16-bit unsigned integers, and reinterprets that array's bytes as float16. Writing big-endian puts each 8-bit value in the second byte of its pair, which a little-endian machine then reads as the high-order byte of the float16, with a zero low-order byte.

    # uint8 values -> big-endian uint16 -> reinterpret the bytes as native float16
    float16 = np.frombuffer(np.array(np.frombuffer(float8s, dtype='u1'), dtype='>u2').tobytes(), dtype='f2')
    print(float16)
    # [6.  2.5 2.5 6. ]
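    An alternative for the way back, again using bitwise operations instead of a byte-order trick, is to widen each 8-bit code to `uint16`, shift it into the high byte and reinterpret the result with `.view`. Because the shift acts on integer values rather than byte positions, it should not depend on the machine's endianness. A minimal sketch; `from_float8` is a hypothetical name:

```python
import numpy as np

def from_float8(codes):
    """Put each 8-bit code in the high byte of a float16 whose low byte is zero."""
    codes = np.asarray(codes, dtype=np.uint8)
    return (codes.astype(np.uint16) << 8).view(np.float16)
```

    Usage: `from_float8(np.frombuffer(b'FAAF', dtype=np.uint8))` gives the same `[6.  2.5 2.5 6. ]` float16 array.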
    

    You can definitely see the loss of precision! I hope this helps. If this is sufficient, let me know. If not, I'd be up for looking deeper into it.