python floating-point ctypes

Half-precision in ctypes


I need to be able to seamlessly interact with half-precision floating-point values in a ctypes structure. I have a working solution, but I'm dissatisfied with it:

import ctypes
import struct


packed = struct.pack('<Ife', 4, 2.3, 1.2)
print('Packed:', packed.hex())


class c_half(ctypes.c_ubyte*2):
    @property
    def value(self) -> float:
        result, = struct.unpack('e', self)
        return result


class Triple(ctypes.LittleEndianStructure):
    _pack_ = 1
    _fields_ = (
        ('index', ctypes.c_uint32),
        ('x', ctypes.c_float),
        ('y', c_half),
    )


unpacked = Triple.from_buffer_copy(packed)
print(unpacked.y.value)

Output:

Packed: 0400000033331340cd3c
1.2001953125

I am dissatisfied because, unlike with c_float, c_uint32 etc., there is no automatic coercion of the buffer data to the Python primitive (float and int respectively for those examples); I would expect float in this half-precision case.
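To make the difference concrete, here is a minimal sketch. The names `HalfBytes` and `Pair` are hypothetical and exist only for this illustration. Reading a `c_float` field returns a plain Python `float`, while reading a two-byte array field returns the ctypes array object itself:

```python
import ctypes

# Hypothetical name: raw two-byte storage standing in for one half-precision value
HalfBytes = ctypes.c_ubyte * 2


class Pair(ctypes.Structure):
    _fields_ = (
        ('x', ctypes.c_float),  # built-in simple type: coerced on access
        ('y', HalfBytes),       # array type: returned as a ctypes object
    )


# 0x3ccd is 1.2 rounded to half precision, stored little-endian as cd 3c
p = Pair(2.3, HalfBytes(0xcd, 0x3c))

print(type(p.x))               # <class 'float'>
print(isinstance(p.y, float))  # False
print(bytes(p.y).hex())        # cd3c
```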

Reading into the CPython source, the built-in types are subclasses of _SimpleCData:

static PyType_Spec pycsimple_spec = {
    .name = "_ctypes._SimpleCData",
    .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
              Py_TPFLAGS_IMMUTABLETYPE),
    .slots = pycsimple_slots,
};

and only declare a _type_, for instance

class c_float(_SimpleCData):
    _type_ = "f"

However, attempting the naive

class c_half(ctypes._SimpleCData):
    _type_ = 'e'

results in

AttributeError: class must define a '_type_' attribute which must be
a single character string containing one of 'cbBhHiIlLdfuzZqQPXOv?g'.

as defined by SIMPLE_TYPE_CHARS:

static const char SIMPLE_TYPE_CHARS[] = "cbBhHiIlLdfuzZqQPXOv?g";
// ...
    if (!strchr(SIMPLE_TYPE_CHARS, *proto_str)) {
        PyErr_Format(PyExc_AttributeError,
                     "class must define a '_type_' attribute which must be\n"
                     "a single character string containing one of '%s'.",
                     SIMPLE_TYPE_CHARS);
        goto error;
    }

The end goal is to have a c_half type that I can use with the exact same API as the other built-in ctypes.c_ classes, ideally without writing a C module myself. I think I need to mimic much of the behaviour seen in the neighbourhood of PyCSimpleType_init, but that code is difficult for me to follow.


Solution

  • Using descriptors gets close to what you want. Declare the ctypes fields with a leading underscore and add the descriptors as class variables under the unprefixed names. If you are not familiar with descriptors, read the Descriptor Guide in the Python documentation.

    import ctypes as ct
    import struct
    
    # Descriptor implementation
    
    class Half:
    
        def __set_name__(self, owner, name):
            self.field = f'_{name}'  # name of ctypes field
    
        def __get__(self, obj, objtype=None):
            if obj is None:
                return self  # accessed on the class, not on an instance
            # Translate ctypes c_half field to float
            data = getattr(obj, self.field)
            return struct.unpack('e', data)[0]
    
        def __set__(self, obj, value):
            # Translate float to ctypes c_half field
            setattr(obj, self.field, c_half(*struct.pack('e', value)))
    
    # two-byte field with display overrides
    
    class c_half(ct.c_ubyte*2):
    
        def __repr__(self):
            return f'c_half({self})'
    
        def __str__(self):
            return str(struct.unpack('e', bytes(self))[0])
    
    class Quad(ct.Structure):
    
        y = Half()  # Declare descriptors
        z = Half()  # Descriptor name must be the ctypes field name without the leading underscore
    
        _pack_ = 1
        _fields_ = (('index', ct.c_uint32),
                    ('x', ct.c_float),
                    ('_y', c_half),  # ctypes fields
                    ('_z', c_half))
    
        # Only needed if you want Quad(1, 2, 3, 4) construction.
        # Without it, Quad() initializes all fields to zero and you
        # must set the c_half fields manually. You can still do
        # Quad(1, 2) to set the non-c_half fields (index and x).
        def __init__(self, index=0, x=0, y=0, z=0):
            self.index = index
            self.x = x
            self.y = y  # assignment goes through the descriptor's __set__
            self.z = z  # assignment goes through the descriptor's __set__
    
        def __repr__(self):
            return f'Quad(index={self.index}, x={self.x}, y={self.y}, z={self.z})'
    
    # Examples
    t = Quad(4, 1.2, 2.3, 3.4)
    print(t)
    print(f'{t.y=} {repr(t._y)}, {t.z=} {repr(t._z)}')
    t.y, t.z = 8.8, 9.9
    print(f'{t.y=} {repr(t._y)}, {t.z=} {repr(t._z)}')
    

    Output:

    Quad(index=4, x=1.2000000476837158, y=2.30078125, z=3.400390625)
    t.y=2.30078125 c_half(2.30078125), t.z=3.400390625 c_half(3.400390625)
    t.y=8.796875 c_half(8.796875), t.z=9.8984375 c_half(9.8984375)
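The same idea can be applied to the packed buffer from the question. The following is a minimal sketch, not a drop-in replacement for a real `c_half` simple type. It assumes the layout from the question (`<Ife`, packed to 1 byte), and it uses `'<e'` rather than the native `'e'` so the conversion matches the little-endian structure on any host; that change is an assumption of this sketch and not part of the code above.

```python
import ctypes as ct
import struct


class c_half(ct.c_ubyte * 2):
    """Raw two-byte storage for one IEEE 754 half-precision value."""


class Half:
    """Descriptor that exposes the underscore-prefixed c_half field as a float."""

    def __set_name__(self, owner, name):
        self.field = f'_{name}'  # name of the underlying ctypes field

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        raw = bytes(getattr(obj, self.field))
        return struct.unpack('<e', raw)[0]

    def __set__(self, obj, value):
        setattr(obj, self.field, c_half(*struct.pack('<e', value)))


class Triple(ct.LittleEndianStructure):
    y = Half()

    _pack_ = 1
    _fields_ = (('index', ct.c_uint32),
                ('x', ct.c_float),
                ('_y', c_half))


packed = struct.pack('<Ife', 4, 2.3, 1.2)
t = Triple.from_buffer_copy(packed)  # bypasses __init__; the descriptor still works

y_read = t.y
print(t.index, y_read)   # 4 1.2001953125

t.y = 0.5                # written back through the descriptor
print(bytes(t).hex())    # 04000000333313400038
```

Reading `t.y` now yields a Python `float` directly, which is the coercion the question asks for, although `Triple._y` is still an array and not a true `_SimpleCData` subclass.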