python, python-3.x, python-dataclasses

How to call class methods during instantiation of frozen dataclasses?


I have a wrapper that adds some methods to a number of dataclasses. The dataclasses are all meant to be frozen, and they should check the types of their initialization values during instantiation.

I also want to overload the constructor of the dataclasses so that they can be instantiated as Foo(fields), Foo(*fields), Foo(mapping) and Foo(**mapping), but NOT as Foo(*args, **kwargs). In every form the number of values passed equals the number of fields: in the first form, fields is a sequence containing the field values in order, and in the third, mapping contains the field names and their values. The constructor should accept a list, a dict, an unpacked list, or an unpacked dict, but not a mixture of those (i.e. not Foo(a, [b, c]) or Foo(a, b, c, d=1)).
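The four accepted call forms can be pinned down as one normalization step that turns each of them into a field-name-to-value dict and rejects anything else. This is a minimal sketch, not code from the question; `normalize_args` and its `field_names` parameter are hypothetical names.

```python
from collections.abc import Mapping, Sequence


def normalize_args(field_names, /, *args, **kwargs):
    """Turn Foo(fields), Foo(*fields), Foo(mapping) or Foo(**mapping) into a dict.

    Any mixture of positional and keyword values is rejected.
    """
    if args and kwargs:
        raise TypeError("pass values either positionally or by keyword, not both")

    if kwargs:
        mapping = kwargs                      # Foo(**mapping)
    elif len(args) == 1 and isinstance(args[0], Mapping):
        mapping = args[0]                     # Foo(mapping)
    else:
        # A single non-string sequence is Foo(fields); anything else is Foo(*fields).
        packed = len(args) == 1 and isinstance(args[0], Sequence) and not isinstance(args[0], str)
        values = args[0] if packed else args
        if len(values) != len(field_names):
            raise TypeError(f"expected {len(field_names)} values, got {len(values)}")
        return dict(zip(field_names, values))

    # Both mapping forms must name every field exactly once and nothing else.
    missing = set(field_names) - set(mapping)
    unknown = set(mapping) - set(field_names)
    if missing or unknown:
        raise TypeError(f"missing fields {sorted(missing)}, unexpected fields {sorted(unknown)}")
    return dict(mapping)
```

Treating a single non-string sequence argument as Foo(fields) is a heuristic: for a one-field class whose value is itself a sequence, Foo(fields) and Foo(*fields) cannot be told apart.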

Refer to this question for more context. As you can see there, I have succeeded in doing this in my hand-rolled class, but it is unpythonic.

This is a minimal reproducible example:

from dataclasses import dataclass, fields, asdict
from datetime import datetime
from typing import Union

SENTINEL = object()

def wrapper(cls):
    cls._name = [f.name for f in fields(cls)]
    @classmethod
    def from_sequence(cls, sequence):        
        for arg, field in zip(sequence, fields(cls)):
            if not isinstance(arg, field.type):
                raise TypeError(f"'{arg}' not of type '{field.type}'.")
        
        return cls(*sequence)
    
    cls.from_sequence = from_sequence
    @classmethod
    def from_dict(cls, mapping):
        for field in fields(cls):
            value = mapping.get(field.name, SENTINEL)
            if value is not SENTINEL and not isinstance(value, field.type):
                raise TypeError(f"Field '{field.name}' value '{value}' not of type '{field.type}'.")
        
        return cls(**mapping)
    
    cls.from_dict = from_dict
    return cls

NoneType = type(None)

@wrapper
@dataclass(frozen=True)
class Person:
    Name: str
    Age: Union[int, float]
    Birthdate: datetime

It sort of works, but the constructor only accepts unpacked arguments (*args or **kwargs), not a packed sequence or mapping, and it does no type checking:

In [2]: Person('Jane Smith', 23, datetime(2000, 1, 1))
Out[2]: Person(Name='Jane Smith', Age=23, Birthdate=datetime.datetime(2000, 1, 1, 0, 0))

In [3]: Person(None, None, None)
Out[3]: Person(Name=None, Age=None, Birthdate=None)

In [4]: Person.from_sequence(['Jane Smith', 23, datetime(2000, 1, 1)])
Out[4]: Person(Name='Jane Smith', Age=23, Birthdate=datetime.datetime(2000, 1, 1, 0, 0))

In [5]: Person.from_sequence([None]*3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 1
----> 1 Person.from_sequence([None]*3)

Cell In[1], line 13, in wrapper.<locals>.from_sequence(cls, sequence)
     11 for arg, field in zip(sequence, fields(cls)):
     12     if not isinstance(arg, field.type):
---> 13         raise TypeError(f"'{arg}' not of type '{field.type}'.")
     15 return cls(*sequence)

TypeError: 'None' not of type '<class 'str'>'.

In [6]: Person([None]*3)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[6], line 1
----> 1 Person([None]*3)

TypeError: Person.__init__() missing 2 required positional arguments: 'Age' and 'Birthdate'
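As the session shows, the plain constructor accepts None for every field. One way to get type checking on every construction path, including the plain constructor, is a `__post_init__` hook, which the generated `__init__` calls after assigning the fields. This is a sketch, not the question's code; it assumes annotations are real types written with `typing.Union` (no `from __future__ import annotations`), and `_matches` is a hypothetical helper that unpacks `Union` via `typing.get_origin`/`get_args` instead of relying on `isinstance` accepting `typing.Union` directly.

```python
from dataclasses import dataclass, fields
from datetime import datetime
from typing import Union, get_args, get_origin


def _matches(value, tp):
    """isinstance() that also understands typing.Union[...] annotations."""
    if get_origin(tp) is Union:
        return any(_matches(value, arg) for arg in get_args(tp))
    return isinstance(value, tp)


@dataclass(frozen=True)
class Person:
    Name: str
    Age: Union[int, float]
    Birthdate: datetime

    def __post_init__(self):
        # Called by the generated __init__ after every field has been set.
        # It only reads attributes, so frozen=True does not get in the way.
        for field in fields(self):
            value = getattr(self, field.name)
            if not _matches(value, field.type):
                raise TypeError(
                    f"Field {field.name!r} value {value!r} not of type {field.type!r}."
                )
```

With this hook, Person(None, None, None) raises TypeError instead of succeeding. It only validates; the packed call forms still need separate handling.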

I tried to override __new__ so that it calls the corresponding classmethod during instantiation:

@wrapper
@dataclass(frozen=True)
class Person:
    Name: str
    Age: Union[int, float]
    Birthdate: datetime
    def __new__(cls,  *args, **kwargs):
        if args:
            assert not kwargs
            data = args if len(args) != 1 else args[0]
            return cls.from_dict(data) if isinstance(data, dict) else cls.from_sequence(data)
        else:
            assert kwargs
            return cls.from_dict(kwargs)

But it doesn't work: it just crashes my interpreter without raising any exception. I figure this is due to infinite recursion: __new__ calls the classmethod, the classmethod calls cls(...), which calls __new__ again, and so on forever, which crashes the interpreter. I tried to override __init__ instead, and the same thing happens.
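Two things get in the way of routing construction through __new__. First, any cls(...) call inside the classmethods goes back through __new__, which produces the endless recursion described above. Second, even when __new__ avoids that by allocating with object.__new__(cls), Python then calls __init__ on the returned instance with the original arguments, so a generated dataclass __init__ would still receive the packed list. The hypothetical `Demo` class below records the calls to show that ordering:

```python
calls = []


class Demo:
    def __new__(cls, *args, **kwargs):
        calls.append(("__new__", args, kwargs))
        # Allocate directly instead of calling cls(...), which would re-enter __new__.
        return object.__new__(cls)

    def __init__(self, *args, **kwargs):
        # Called automatically because __new__ returned an instance of cls;
        # it receives exactly the arguments that were passed to Demo(...).
        calls.append(("__init__", args, kwargs))


Demo([1, 2, 3])
# calls == [("__new__", ([1, 2, 3],), {}), ("__init__", ([1, 2, 3],), {})]
```

This is one reason taking over __init__, as the solution below does, is simpler than taking over __new__.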

How do I properly override instantiation of a frozen dataclass so that the appropriate classmethod is called to validate the data types and handle the arguments?


Solution

  • I finally got it working, though I don't know whether it is Pythonic.

    In short, I extracted the attribute-setting logic out of the classmethods into helper methods that both __init__ and the classmethods call, and I overrode __repr__ so that an instance that has been created but not yet populated can still be displayed:

    from dataclasses import dataclass, fields, asdict, FrozenInstanceError
    from collections.abc import Sequence, Mapping
    from datetime import datetime
    from typing import Union
    
    SENTINEL = object()
    
    def wrapper(cls):
        cls._name = [f.name for f in fields(cls)]
        cls._initialized = False
        def _populate_from_sequence(self, sequence):
            if self._initialized:
                raise FrozenInstanceError('Object has already been populated')
            if isinstance(sequence, str) or not isinstance(sequence, Sequence):
                raise TypeError(f'argument sequence of type {type(sequence)!r} is not a Sequence')
            # zip() stops silently at the shorter input, so check the length up front.
            if len(sequence) != len(cls._name):
                raise TypeError(f'expected {len(cls._name)} values, got {len(sequence)}')
    
            for arg, field in zip(sequence, fields(cls)):
                if not isinstance(arg, field.type):
                    raise TypeError(f"Field '{field.name}' value {arg!r} not of type {field.type!r}.")
                # The frozen dataclass blocks normal assignment, so bypass its __setattr__.
                object.__setattr__(self, field.name, arg)
            object.__setattr__(self, '_initialized', True)
    
        cls._populate_from_sequence = _populate_from_sequence
    
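        @classmethod
        def from_sequence(cls, sequence):
            # Allocate without going through __init__ (or re-entering __new__),
            # then fill in the fields directly.
            instance = object.__new__(cls)
            instance._populate_from_sequence(sequence)
            return instance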
    
        cls.from_sequence = from_sequence
    
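        def _populate_from_mapping(self, mapping):
            if not isinstance(mapping, Mapping):
                raise TypeError(f'argument mapping of type {type(mapping)!r} is not a Mapping')
            # Pass a list, not a generator: _populate_from_sequence only accepts a
            # Sequence, and a generator expression is not one. Missing keys become None.
            self._populate_from_sequence([mapping.get(field.name) for field in fields(cls)])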
    
        cls._populate_from_mapping = _populate_from_mapping
        @classmethod
        def from_dict(cls, mapping):
            instance = object.__new__(cls)
            instance._populate_from_mapping(mapping)
            return instance
    
        cls.from_dict = from_dict
        def __repr__(self):
            return (
                f'{cls.__name__}('
                + ', '.join(repr(getattr(self, name)) for name in cls._name)
                + ')'
                if self._initialized
                else object.__repr__(self)
            )
        
        cls.__repr__ = __repr__
        return cls
    
    NoneType = type(None)
    
    @wrapper
    @dataclass(frozen=True, order=True)
    class Person:
        Name: str
        Age: Union[int, float]
        Birthdate: datetime
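        def __init__(self, *args, **kwargs):
            # Accept exactly one calling style: positional values (packed or
            # unpacked) or keyword values, never a mixture. Raise instead of
            # using assert, which is stripped when Python runs with -O.
            if args and kwargs:
                raise TypeError('pass values either positionally or by keyword, not both')
            if args:
                data = args if len(args) != 1 else args[0]
                if isinstance(data, Mapping):
                    self._populate_from_mapping(data)
                else:
                    self._populate_from_sequence(data)
            elif kwargs:
                self._populate_from_mapping(kwargs)
            else:
                raise TypeError(f'{type(self).__name__}() requires arguments')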
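For comparison, a different approach (not the one above) keeps the generated __init__, moves the unpacking of Foo(fields) and Foo(mapping) into a metaclass __call__, which runs before __new__ and __init__, and moves the type checks into __post_init__. This is a sketch under stated assumptions: Python 3.8+, annotations that are plain classes or typing.Union of classes, and hypothetical names `NormalizingMeta` and `_matches`.

```python
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, fields
from datetime import datetime
from typing import Union, get_args, get_origin


def _matches(value, tp):
    # isinstance() that also understands typing.Union[...] annotations.
    if get_origin(tp) is Union:
        return any(_matches(value, arg) for arg in get_args(tp))
    return isinstance(value, tp)


class NormalizingMeta(type):
    """Unpacks Foo(fields) / Foo(mapping) before normal construction starts."""

    def __call__(cls, *args, **kwargs):
        if args and kwargs:
            raise TypeError(f"{cls.__name__}() takes positional or keyword values, not both")
        if len(args) == 1:
            (packed,) = args
            if isinstance(packed, Mapping):
                return super().__call__(**packed)   # Foo(mapping) -> Foo(**mapping)
            if isinstance(packed, Sequence) and not isinstance(packed, str):
                return super().__call__(*packed)    # Foo(fields) -> Foo(*fields)
        # Foo(*fields) and Foo(**mapping) go straight to the generated __init__,
        # which reports missing or unexpected arguments on its own.
        return super().__call__(*args, **kwargs)


@dataclass(frozen=True, order=True)
class Person(metaclass=NormalizingMeta):
    Name: str
    Age: Union[int, float]
    Birthdate: datetime

    def __post_init__(self):
        # Runs after the generated __init__ has set every field; read-only.
        for field in fields(self):
            value = getattr(self, field.name)
            if not _matches(value, field.type):
                raise TypeError(f"Field {field.name!r} value {value!r} not of type {field.type!r}.")

    # The alternate constructors become thin aliases for the normal call.
    @classmethod
    def from_sequence(cls, sequence):
        return cls(sequence)

    @classmethod
    def from_dict(cls, mapping):
        return cls(mapping)
```

Nothing here re-enters __new__ or needs object.__setattr__, and the dataclass-generated __repr__, __eq__ and ordering keep working. The trade-offs are that a custom metaclass can conflict with another metaclass (for example abc.ABCMeta) on the same class, and a one-field class whose value is itself a sequence would be ambiguous.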