python, python-dataclasses, python-descriptors

What is the proper way to use descriptors as fields in Python dataclasses?


I've been playing around with Python dataclasses and was wondering: what is the most elegant or most Pythonic way to make one or more of the fields descriptors?

In the example below I define a Vector2D class whose instances should be compared by their length.

from dataclasses import dataclass, field
from math import sqrt

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: int = field(init=False)
    
    def __post_init__(self):
        type(self).length = property(lambda s: sqrt(s.x**2+s.y**2))

Vector2D(3,4) > Vector2D(4,1) # True

While this code works, it touches the class every time an instance is created. Is there a more readable, less hacky, more intended way to use dataclasses and descriptors together?
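To make the "touches the class every time" point concrete, here is a small sketch of the same class that looks at the class attribute after each instantiation. The names `first` and `second` are only for illustration, and `length` is annotated as `float` here because `sqrt` returns a float.

```python
from dataclasses import dataclass, field
from math import sqrt

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: float = field(init=False)

    def __post_init__(self):
        # Rebinds the class attribute on every instantiation.
        type(self).length = property(lambda s: sqrt(s.x**2 + s.y**2))

a = Vector2D(3, 4)
first = vars(Vector2D)["length"]   # property installed by the first instance
b = Vector2D(4, 1)
second = vars(Vector2D)["length"]  # a new property installed by the second

print(first is second)  # False: the class was modified again
print(a > b)            # True
```

The comparison still works, but every constructor call replaces the property on the class, which is the part that feels hacky.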

Just having length as a property and not a field would work, but then I would have to write __lt__, et al. by myself.

Another solution I found is equally unappealing:

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: int = field(init=False)
    
    @property
    def length(self):
        return sqrt(self.x**2+self.y**2)
    
    @length.setter
    def length(self, value):
        pass

Introducing a no-op setter is necessary because, apparently, the dataclass-generated __init__ method tries to assign to length, even though I gave it no default value and explicitly set init=False.
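One way to see why the assignment happens is to inspect the field that the decorator actually records. The sketch below assumes current CPython behaviour: the later `@property` definition rebinds the name `length` in the class body, so the `field(init=False)` object is already gone when `@dataclass` runs, and the property object is picked up as an ordinary default value. The variable `length_field` is only for illustration.

```python
from dataclasses import dataclass, field, fields
from math import sqrt

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: float = field(init=False)  # overwritten by the property below

    @property
    def length(self):
        return sqrt(self.x**2 + self.y**2)

    @length.setter
    def length(self, value):
        pass  # swallows the assignment made by the generated __init__

length_field = next(f for f in fields(Vector2D) if f.name == "length")
print(length_field.init)                           # True: init=False was lost
print(isinstance(length_field.default, property))  # True
print(Vector2D(3, 4) > Vector2D(4, 1))             # True
```

So length ends up as a regular init parameter whose default is the property object, and the generated __init__ assigns it to self.length, which is why a setter has to exist.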

Surely there has to be a better way, right?


Solution

  • This might not answer your exact question, but you mentioned that the reason you didn't want length to be a property rather than a field was that you would have to

    write __lt__, et al. by myself

    While you do have to implement __lt__ yourself, you can get away with implementing just that one ordering method, plus an __eq__ that also compares by length:

    from functools import total_ordering
    from dataclasses import dataclass
    from math import sqrt
    
    @total_ordering
    @dataclass
    class Vector2D:
        x: int
        y: int
    
        @property
        def length(self):
            return sqrt(self.x ** 2 + self.y ** 2)
    
        def __lt__(self, other):
            if not isinstance(other, Vector2D):
                return NotImplemented
    
            return self.length < other.length
    
        def __eq__(self, other):
            if not isinstance(other, Vector2D):
                return NotImplemented
    
            return self.length == other.length
    
    
    print(Vector2D(3, 4) > Vector2D(4, 1))
    

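    This works because total_ordering fills in the remaining ordering methods (__le__, __gt__ and __ge__) based on __lt__ and __eq__. The dataclass decorator leaves an __eq__ defined in the class body alone, so the length-based __eq__ is the one that gets used.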
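As a quick check of that claim, the sketch below repeats the class and then looks at which methods total_ordering added and how the derived operators behave. The names `derived` and `results` are only for illustration.

```python
from dataclasses import dataclass
from functools import total_ordering
from math import sqrt

@total_ordering
@dataclass
class Vector2D:
    x: int
    y: int

    @property
    def length(self):
        return sqrt(self.x ** 2 + self.y ** 2)

    def __lt__(self, other):
        if not isinstance(other, Vector2D):
            return NotImplemented
        return self.length < other.length

    def __eq__(self, other):
        if not isinstance(other, Vector2D):
            return NotImplemented
        return self.length == other.length

a, b = Vector2D(3, 4), Vector2D(4, 1)

# Methods that were not written by hand but now exist on the class.
derived = [name for name in ("__le__", "__gt__", "__ge__") if name in vars(Vector2D)]
print(derived)  # ['__le__', '__gt__', '__ge__']

# Derived operators, plus the length-based equality.
results = (a > b, a >= b, a <= b, Vector2D(3, 4) == Vector2D(4, 3))
print(results)  # (True, True, False, True)
```

Note that with this __eq__, two vectors with different coordinates but the same length compare equal, which matches the "compare by length" goal but may or may not be what you want elsewhere.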