
Using __new__ in inherited dataclasses


Suppose I have the following code that is used to handle links between individuals and countries:

from dataclasses import dataclass

@dataclass
class Country:
    iso2 : str
    iso3 : str
    name : str

countries = [ Country('AW','ABW','Aruba'),
              Country('AF','AFG','Afghanistan'),
              Country('AO','AGO','Angola')]
countries_by_iso2 = {c.iso2 : c for c in countries}
countries_by_iso3 = {c.iso3 : c for c in countries}

@dataclass
class CountryLink:
    person_id : int
    country : Country

country_links = [ CountryLink(123, countries_by_iso2['AW']),
                  CountryLink(456, countries_by_iso3['AFG']),
                  CountryLink(789, countries_by_iso2['AO'])]

print(country_links[0].country.name)

This all works fine, but I want to make it less clunky to handle the different forms of input. I also want to use __new__ to make sure that we get a valid ISO code each time, and I want the object to fail to be created otherwise. I therefore add a couple of new classes that inherit from CountryLink:

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso2[iso2]
        return new_obj

@dataclass
class CountryLinkFromISO3(CountryLink):
    def __new__(cls, person_id : int, iso3 : str):
        if iso3 not in countries_by_iso3:
            return None
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso3[iso3]
        return new_obj

country_links = [ CountryLinkFromISO2(123, 'AW'),
                  CountryLinkFromISO3(456, 'AFG'),
                  CountryLinkFromISO2(789, 'AO')]

This appears to work at first glance, but then I run into a problem:

a = CountryLinkFromISO2(123, 'AW')
print(type(a))
print(a.country)
print(type(a.country))

prints:

<class '__main__.CountryLinkFromISO2'>
AW
<class 'str'>

The inherited object has the right type, but its attribute country is just a string instead of the Country type that I expect. I have put in print statements in the __new__ that check the type of new_obj.country, and it is correct before the return line.

What I want is for a to be an object of type CountryLinkFromISO2 that inherits any changes I make to CountryLink, and whose country attribute is taken from the dictionary countries_by_iso2. How can I achieve this?


Solution

  • Just because the @dataclass decorator generates it behind the scenes doesn't mean your classes don't have an __init__(). They do, and it looks like this:

    def __init__(self, person_id: int, country: Country):
        self.person_id = person_id
        self.country = country
    

    When you create an instance with:

    CountryLinkFromISO2(123, 'AW')
    

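    Python first calls __new__(), and because that returns an instance of the class, it then calls __init__() with the same arguments. The "AW" string is therefore passed to __init__() as country, overwriting the Country object that __new__() assigned.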
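    A minimal reproduction of this, assuming a cut-down countries_by_iso2 that only contains Aruba (the print() calls are added purely for illustration). The class call runs __new__() first and then, because the returned object is an instance of the class, runs the generated __init__() with the same arguments:

    ```python
    from dataclasses import dataclass


    @dataclass
    class Country:
        iso2: str
        iso3: str
        name: str


    # Assumption: a one-entry lookup table is enough to show the behaviour.
    countries_by_iso2 = {"AW": Country("AW", "ABW", "Aruba")}


    @dataclass
    class CountryLink:
        person_id: int
        country: Country


    @dataclass
    class CountryLinkFromISO2(CountryLink):
        def __new__(cls, person_id: int, iso2: str):
            if iso2 not in countries_by_iso2:
                return None
            new_obj = super().__new__(cls)
            new_obj.country = countries_by_iso2[iso2]
            # Still a Country object here...
            print("in __new__:", type(new_obj.country).__name__)
            return new_obj


    # ...but the generated __init__(self, person_id, country) runs next
    # and assigns self.country = "AW", replacing the Country object.
    a = CountryLinkFromISO2(123, "AW")
    print("after __init__:", type(a.country).__name__)

    # Roughly what the class call does under the hood:
    obj = CountryLinkFromISO2.__new__(CountryLinkFromISO2, 123, "AW")
    if isinstance(obj, CountryLinkFromISO2):
        obj.__init__(123, "AW")
    ```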

    Using __new__() in this way is fragile, and returning None from a constructor is fairly un-pythonic (imo). You may be better off writing an actual factory method that returns either None or an instance of the class you want. Then you don't need to mess with __new__() at all.

    @dataclass
    class CountryLinkFromISO2(CountryLink):
        @classmethod
        def from_country_code(cls, person_id : int, iso2 : str):
            if iso2 not in countries_by_iso2:
                return None
            return cls(person_id, countries_by_iso2[iso2])
    
    a = CountryLinkFromISO2.from_country_code(123, 'AW')
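    A variation on the same idea, sketched under the assumption that raising an exception is acceptable to you instead of getting None back: the hypothetical classmethods from_iso2() and from_iso3() below raise ValueError for an unknown code, so an invalid link can never slip into a list as None. Defining them on CountryLink itself also means every subclass inherits them:

    ```python
    from dataclasses import dataclass


    @dataclass
    class Country:
        iso2: str
        iso3: str
        name: str


    countries = [Country("AW", "ABW", "Aruba"),
                 Country("AF", "AFG", "Afghanistan"),
                 Country("AO", "AGO", "Angola")]
    countries_by_iso2 = {c.iso2: c for c in countries}
    countries_by_iso3 = {c.iso3: c for c in countries}


    @dataclass
    class CountryLink:
        person_id: int
        country: Country

        # Hypothetical alternative constructors: raise on a bad code
        # instead of returning None.
        @classmethod
        def from_iso2(cls, person_id: int, iso2: str) -> "CountryLink":
            country = countries_by_iso2.get(iso2)
            if country is None:
                raise ValueError(f"unknown ISO2 code: {iso2!r}")
            return cls(person_id, country)

        @classmethod
        def from_iso3(cls, person_id: int, iso3: str) -> "CountryLink":
            country = countries_by_iso3.get(iso3)
            if country is None:
                raise ValueError(f"unknown ISO3 code: {iso3!r}")
            return cls(person_id, country)


    country_links = [CountryLink.from_iso2(123, "AW"),
                     CountryLink.from_iso3(456, "AFG"),
                     CountryLink.from_iso2(789, "AO")]
    print(country_links[1].country.name)  # Afghanistan
    ```

    Because the failure is an exception, callers either handle it explicitly with try/except or it propagates, rather than a None turning up later as an AttributeError.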
    

    If for some reason it needs to work with __new__(), you could return None from __new__() when there's no match, and set the country in __post_init__():

    @dataclass
    class CountryLinkFromISO2(CountryLink):
        def __new__(cls, person_id : int, iso2 : str):
            # No match: return None, so __init__() is never called
            if iso2 not in countries_by_iso2:
                return None
            return super().__new__(cls)

        def __post_init__(self):
            # __init__() stored the raw ISO2 string; swap in the Country
            self.country = countries_by_iso2[self.country]
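    A self-contained check of the __new__()/__post_init__() approach, with a CountryLinkFromISO3 counterpart added by analogy (it is not part of the original answer):

    ```python
    from dataclasses import dataclass


    @dataclass
    class Country:
        iso2: str
        iso3: str
        name: str


    countries = [Country("AW", "ABW", "Aruba"),
                 Country("AF", "AFG", "Afghanistan"),
                 Country("AO", "AGO", "Angola")]
    countries_by_iso2 = {c.iso2: c for c in countries}
    countries_by_iso3 = {c.iso3: c for c in countries}


    @dataclass
    class CountryLink:
        person_id: int
        country: Country


    @dataclass
    class CountryLinkFromISO2(CountryLink):
        def __new__(cls, person_id: int, iso2: str):
            # Returning something that is not an instance of cls
            # makes Python skip __init__() entirely.
            if iso2 not in countries_by_iso2:
                return None
            return super().__new__(cls)

        def __post_init__(self):
            # The generated __init__() put the raw code in self.country;
            # replace it with the matching Country object.
            self.country = countries_by_iso2[self.country]


    @dataclass
    class CountryLinkFromISO3(CountryLink):
        def __new__(cls, person_id: int, iso3: str):
            if iso3 not in countries_by_iso3:
                return None
            return super().__new__(cls)

        def __post_init__(self):
            self.country = countries_by_iso3[self.country]


    a = CountryLinkFromISO2(123, "AW")
    b = CountryLinkFromISO3(456, "AFG")
    missing = CountryLinkFromISO2(789, "XX")
    print(type(a.country).__name__, a.country.name)  # Country Aruba
    print(b.country.name)  # Afghanistan
    print(missing)  # None
    ```

    One caveat that illustrates the fragility: the generated __init__() still names its second parameter country, so a keyword call such as CountryLinkFromISO2(123, iso2='AW') passes __new__() but then fails in __init__() with a TypeError.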