Search code examples
Tags: python, python-3.x, enums, python-dataclasses

How to get the index of a dataclass field


Say I have a simple dataclass instance

import dataclasses as dc

@dc.dataclass
class DCItem:
    """Minimal record with a name and a unit price."""

    name: str
    unit_price: float


item = DCItem(name='test', unit_price=11)

Now I want to determine the position (index) of the instance attribute `item.unit_price`. How can I make this both simple to use and performant? I thought about adding a `get_index` method based on `dc.asdict`:

@dc.dataclass
class DCItem:
    name: str
    unit_price: float

    def get_index(self, name):
        """Return the position of field *name* in this dataclass.

        ``dc.asdict`` recursively deep-copies every field value just so the
        keys can be read, which is why this gets slow as the number (or size)
        of fields grows.

        Raises:
            ValueError: if *name* is not a field of the dataclass.
        """
        return list(dc.asdict(self)).index(name)


# Re-create the instance: an `item` built from the earlier DCItem definition
# belongs to that old class, which has no get_index method.
item = DCItem('test', 11)
item.get_index('unit_price')  # 1

But this has two drawbacks:

  1. It's not very performant, at least not for many instance attributes
  2. It loses the nice auto-completion of `item.unit_price`

Is there a solution that combines the features of a dataclass with those of `IntEnum` and `enum.auto()`, without the drawbacks above?


Solution

  • If the class is not going to be changed at runtime, you can cache the indexes in a dictionary stored as a class attribute.

    import dataclasses as dc
    
    @dc.dataclass
    class DCItem:
        name: str
        unit_price: float

        @classmethod
        def get_index(cls, name):
            """Return the position of field *name* among the dataclass fields.

            The name -> index mapping is built once per class on first use and
            cached in the class attribute ``_idx_mapping``, so later calls are a
            single dict lookup.

            Raises:
                KeyError: if *name* is not a field of the dataclass.
            """
            # Check cls.__dict__ (not hasattr) so a subclass never reuses the
            # parent's cached mapping and instead builds one from its own fields.
            if '_idx_mapping' not in cls.__dict__:
                cls._idx_mapping = {
                    fld.name: idx for idx, fld in enumerate(dc.fields(cls))
                }
            return cls._idx_mapping[name]
    
    
    >>> item = DCItem('test', 11)
    >>> item.get_index('unit_price')
    1
    

    Dictionary lookup is fast: O(1) on average (O(n) only in the pathological worst case).

    >>> from timeit import timeit
    >>> timeit("item.get_index('unit_price')", "from __main__ import item")
    0.21372696105390787
    

    For comparison, your solution is quite slow, as you mentioned:

    >>> timeit("item.get_index('unit_price')", "from __main__ import item")
    4.260601775022224
    

    Note: I haven't tested this class with inheritance.


    EDIT: Solving the second drawback (auto-completion) makes the solution more complex. Here is an approach based on Python descriptors:

    import dataclasses as dc
    from typing import Any
    from collections import defaultdict
    
    
    class IndexedField:
        """Mutable holder for a field value plus its declared type and index.

        The value is type-checked with ``isinstance`` on construction and on
        every assignment to ``value``. The check is strict: for example, an
        ``int`` is rejected where ``float`` is expected.

        Equality compares type, index and value, so dataclasses whose fields
        hold ``IndexedField`` objects compare equal when their contents match.
        """

        def __init__(self, a_type: type, value: Any, index: int):
            self._validate_type(a_type, value)  # This line can be removed when type checking is not required.
            self._a_type = a_type
            self._value = value
            self._index = index

        @staticmethod
        def _validate_type(a_type: type, value: Any):
            """Raise TypeError unless *value* is an instance of *a_type*."""
            if not isinstance(value, a_type):
                raise TypeError(f"value is of type {type(value)} but {a_type} is expected")

        @property
        def a_type(self):  # read-only
            """The type every stored value must be an instance of."""
            return self._a_type

        @property
        def index(self):  # read-only
            """The field position given at construction."""
            return self._index

        @property
        def value(self):
            """The stored value; assignments are type-checked."""
            return self._value

        @value.setter
        def value(self, new_value):
            self._validate_type(self._a_type, new_value)  # This line can be removed when type checking is not required.
            self._value = new_value

        def __eq__(self, other):
            # Without this, equality falls back to identity. The __eq__ that
            # @dataclass generates for the owning class compares field tuples,
            # so it would report two instances with identical contents as unequal.
            if not isinstance(other, IndexedField):
                return NotImplemented
            return ((self._a_type, self._index, self._value)
                    == (other._a_type, other._index, other._value))

        # Equality is value-based and ``value`` is mutable, so instances are
        # deliberately unhashable (same convention as list/dict). Defining
        # __eq__ already implies this; it is spelled out for clarity.
        __hash__ = None

        def __repr__(self):
            return (f'{self.__class__.__name__}'
                    f'(a_type={self._a_type!r}, index={self._index!r}, value={self._value!r})')
    
    
    class IndexedFieldDescriptor:
        """Data descriptor that stores every assigned value as an ``IndexedField``.

        A field's index is its position among the ``IndexedFieldDescriptor``
        attributes declared directly in the owning class body, in definition
        order. Inherited descriptors keep the index they got in their own class.
        """

        def __init__(self, a_type) -> None:
            self._name = None
            self._index = None
            self._type = a_type

        def __set_name__(self, owner, name):
            self._name = name
            # The index is derived from the owner's own namespace, which keeps
            # definition order, and stored on this descriptor. Previously it was
            # kept in class-wide dicts keyed by owner.__name__. That made two
            # classes with the same name share (and keep incrementing) one
            # counter, and made __set__ raise KeyError for subclass instances.
            descriptor_names = [attr for attr, value in vars(owner).items()
                                if isinstance(value, IndexedFieldDescriptor)]
            self._index = descriptor_names.index(name)

        def __get__(self, instance, owner):
            if instance is None:
                return self
            try:
                return instance.__dict__[self._name]
            except KeyError:
                # Follow the attribute protocol so hasattr()/getattr(default) work.
                raise AttributeError(
                    f"{type(instance).__name__!r} object has no attribute {self._name!r}"
                ) from None

        def __set__(self, instance, value):
            instance.__dict__[self._name] = IndexedField(self._type, value, self._index)
    
    
    @dc.dataclass()
    class DCItem:
        """Dataclass whose fields are stored as ``IndexedField`` objects.

        Each field uses an ``IndexedFieldDescriptor`` as its class-level
        default, so ``item.<field>`` yields the stored ``IndexedField``.
        """

        name: IndexedField = IndexedFieldDescriptor(str)  # first descriptor field
        unit_price: IndexedField = IndexedFieldDescriptor(float)  # second descriptor field
    
    
    item = DCItem('test', 11.0)
    print(item)
    print(f"* name field value: {item.name.value!r}, "
          f"name field index: {item.name.index!r}, "
          f"name field type: {item.name.a_type!r}")
    print(f"* unit_price field value: {item.unit_price.value!r}, "
          f"unit_price field index: {item.unit_price.index!r}, "
          f"unit_price field type: {item.unit_price.a_type!r}")

    from timeit import timeit

    # Pass this module's globals instead of a "from __main__ import item" setup
    # string: the setup-string form only works when this file runs as __main__.
    print(f'* Index access time: {timeit("item.name.index", globals=globals())}')
    print(f'* Value access time: {timeit("item.name.value", globals=globals())}')
    

    Output:

    DCItem(name=IndexedField(a_type=<class 'str'>, index=0, value='test'), unit_price=IndexedField(a_type=<class 'float'>, index=1, value=11.0))
    * name field value: 'test', name field index: 0, name field type: <class 'str'>
    * unit_price field value: 11.0, unit_price field index: 1, unit_price field type: <class 'float'>
    * Index access time: 0.2253845389932394
    * Value access time: 0.2729280750500038