
Suppose I have the following code that is used to handle links between individuals and countries:

from dataclasses import dataclass

@dataclass
class Country:
    iso2 : str
    iso3 : str
    name : str

countries = [ Country('AW','ABW','Aruba'),
              Country('AF','AFG','Afghanistan'),
              Country('AO','AGO','Angola')]
countries_by_iso2 = {c.iso2 : c for c in countries}
countries_by_iso3 = {c.iso3 : c for c in countries}

@dataclass
class CountryLink:
    person_id : int
    country : Country

country_links = [ CountryLink(123, countries_by_iso2['AW']),
                  CountryLink(456, countries_by_iso3['AFG']),
                  CountryLink(789, countries_by_iso2['AO'])]

print(country_links[0].country.name)

This is all working fine, but I want to make it less clunky to handle the different forms of input. I also want to use __new__ to make sure that we get a valid ISO code each time, and I want the object to fail to be created if we don't. I therefore add a couple of new classes that inherit from CountryLink:

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso2[iso2]
        return new_obj

@dataclass
class CountryLinkFromISO3(CountryLink):
    def __new__(cls, person_id : int, iso3 : str):
        if iso3 not in countries_by_iso3:
            return None
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso3[iso3]
        return new_obj

country_links = [ CountryLinkFromISO2(123, 'AW'),
                  CountryLinkFromISO3(456, 'AFG'),
                  CountryLinkFromISO2(789, 'AO')]

This appears to work at first glance, but then I run into a problem:

a = CountryLinkFromISO2(123, 'AW')
print(type(a))
print(a.country)
print(type(a.country))

prints:

<class '__main__.CountryLinkFromISO2'>
AW
<class 'str'>

The inherited object has the right type, but its country attribute is just a string instead of the Country type that I expect. I have put print statements in __new__ that check the type of new_obj.country, and it is correct just before the return line.

What I want is for a to be an object of type CountryLinkFromISO2 that inherits any changes I make to CountryLink and has a country attribute taken from the dictionary countries_by_iso2. How can I achieve this?

EdG
  • Are you sure you want to be overriding `__new__` and not `__init__`? You might also consider using a library like [attrs](https://www.attrs.org/en/stable/). – Nathaniel Ford Aug 08 '21 at 19:05
  • @NathanielFord `__init__` would always return an instance of the class, which I do not want to happen if the input is invalid. I could have an `__init__` that raises an exception, but that means that every time I try to call my code I have to put it in a try/except block, which is clunky and can cause performance issues. – EdG Aug 08 '21 at 19:16
  • I think Mark provided the correct course (factory method), but you should check out attrs or similar libraries for their validators. – Nathaniel Ford Aug 09 '21 at 17:44

2 Answers


Just because the dataclass does it behind the scenes doesn't mean your classes don't have an __init__(). They do, and it looks like this:

def __init__(self, person_id: int, country: Country):
    self.person_id = person_id
    self.country = country

When you create an instance with:

CountryLinkFromISO2(123, 'AW')

that "AW" string gets passed to __init__(), which assigns it to country, overwriting the Country you set in __new__().
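To see the order of calls for yourself, here is a small tracing sketch. The `calls` list and the use of `__post_init__()` as an observation hook are additions for illustration only, and the country table is trimmed to one entry. It checks the parameters of the generated `__init__()`, then records what `__new__()` receives and what the generated `__init__()` has left in `country` by the time it finishes.

```python
import inspect
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

countries_by_iso2 = {'AW': Country('AW', 'ABW', 'Aruba')}

@dataclass
class CountryLink:
    person_id: int
    country: Country

calls = []  # hypothetical trace of what each step sees

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id: int, iso2: str):
        calls.append(('__new__', person_id, iso2))
        new_obj = super().__new__(cls)
        new_obj.country = countries_by_iso2[iso2]  # correct Country here...
        return new_obj

    def __post_init__(self):
        # ...but the generated __init__ has already reassigned country
        # by the time it calls __post_init__
        calls.append(('__init__ done', self.person_id, self.country))

# The decorator generated an __init__ with one parameter per field
print(list(inspect.signature(CountryLink.__init__).parameters))
# ['self', 'person_id', 'country']

a = CountryLinkFromISO2(123, 'AW')
print(calls)
# [('__new__', 123, 'AW'), ('__init__ done', 123, 'AW')]
```

Both steps receive the same positional arguments, so the string 'AW' is what ends up in country.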

Using __new__() in this way is fragile, and returning None from a constructor is fairly un-pythonic (imo). You may be better off writing an actual factory method that returns either None or an instance of the class you want. Then you don't need to mess with __new__() at all.

@dataclass
class CountryLinkFromISO2(CountryLink):
    @classmethod
    def from_country_code(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        return cls(person_id, countries_by_iso2[iso2])

a = CountryLinkFromISO2.from_country_code(123, 'AW')

If for some reason it needs to work with __new__(), you could return None from __new__() when there's no match, and set the country in __post_init__():

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id : int, iso2 : str):
        if iso2 not in countries_by_iso2:
            return None
        return super().__new__(cls)
    
    def __post_init__(self):
        self.country = countries_by_iso2[self.country]
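A quick check of the __post_init__() variant above, with both a valid and an invalid code. It relies on a documented rule of Python's data model: if `__new__()` returns something that is not an instance of `cls`, `__init__()` is not called, so `None` comes straight back to the caller. The country table is trimmed for brevity.

```python
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

countries = [Country('AW', 'ABW', 'Aruba'),
             Country('AF', 'AFG', 'Afghanistan')]
countries_by_iso2 = {c.iso2: c for c in countries}

@dataclass
class CountryLink:
    person_id: int
    country: Country

@dataclass
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id: int, iso2: str):
        if iso2 not in countries_by_iso2:
            # Not an instance of cls, so Python skips __init__ entirely
            return None
        return super().__new__(cls)

    def __post_init__(self):
        # The generated __init__ stored the raw code; swap in the Country
        self.country = countries_by_iso2[self.country]

good = CountryLinkFromISO2(123, 'AW')
bad = CountryLinkFromISO2(456, 'XX')
print(good.country.name)  # Aruba
print(bad)                # None
```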
Mark
  • Does this mean that the `__init__` created implicitly by the dataclass decorator runs after the `__new__` with the same positional arguments, and therefore overwrites the previous value of `country`? Returning `None` from the constructor may be unpythonic, but the best alternative would be try/except blocks, which I understand can cause performance problems. – EdG Aug 08 '21 at 19:27
  • Yes, `__init__()` is created by the dataclass decorator. It's [documented here](https://docs.python.org/3/library/dataclasses.html) and is one of the reasons for using dataclasses. And `__init__()` is called after `__new__()`. That's not specific to dataclasses. – Mark Aug 08 '21 at 20:00
  • So if I want to avoid the `__init__` replacing the value of `country` that was set in the `__new__`, would I need to manually specify both the `__init__` and the `__new__` so it would take the same positional arguments but deliberately not overwrite? – EdG Aug 08 '21 at 20:19
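On that last question: one option, sketched here under the assumption that you want to keep the `__new__()` check, is to leave the subclass undecorated and hand-write an `__init__()` that takes the same arguments as `__new__()`, translates the code and delegates to the dataclass-generated `CountryLink.__init__()`. The subclass still inherits the generated `__repr__()` and `__eq__()`.

```python
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

countries = [Country('AW', 'ABW', 'Aruba'),
             Country('AO', 'AGO', 'Angola')]
countries_by_iso2 = {c.iso2: c for c in countries}

@dataclass
class CountryLink:
    person_id: int
    country: Country

# No @dataclass here: the hand-written __init__ below replaces the
# inherited one, while __repr__ and __eq__ come from CountryLink.
class CountryLinkFromISO2(CountryLink):
    def __new__(cls, person_id: int, iso2: str):
        if iso2 not in countries_by_iso2:
            return None  # not an instance of cls, so __init__ is skipped
        return super().__new__(cls)

    def __init__(self, person_id: int, iso2: str):
        # Translate the code, then let the generated __init__ set the fields
        super().__init__(person_id, countries_by_iso2[iso2])

a = CountryLinkFromISO2(123, 'AW')
print(a)
# CountryLinkFromISO2(person_id=123, country=Country(iso2='AW', iso3='ABW', name='Aruba'))
print(CountryLinkFromISO2(456, 'XX'))  # None
```

The catch is that this __init__() must be updated by hand whenever CountryLink gains fields, which is part of why the factory method above is simpler.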

The behaviour you see is because dataclasses set their fields in __init__, which happens after __new__ has run.

The Pythonic way to solve this would be to provide alternate constructors. I would drop the subclasses, as they exist only for their constructors.

For example:

@dataclass
class CountryLink:
    person_id: int
    country: Country

    @classmethod
    def from_iso2(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso2[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO2 country code {country_code!r}') from None

    @classmethod
    def from_iso3(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso3[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO3 country code {country_code!r}') from None

country_links = [ CountryLink.from_iso2(123, 'AW'),
                  CountryLink.from_iso3(456, 'AFG'),
                  CountryLink.from_iso2(789, 'AO')]
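A usage sketch for these alternate constructors. The `rows` input is hypothetical; the point is that the try/except can live in the one loop where input is read, rather than around every individual construction.

```python
from dataclasses import dataclass

@dataclass
class Country:
    iso2: str
    iso3: str
    name: str

countries = [Country('AW', 'ABW', 'Aruba'),
             Country('AF', 'AFG', 'Afghanistan'),
             Country('AO', 'AGO', 'Angola')]
countries_by_iso2 = {c.iso2: c for c in countries}
countries_by_iso3 = {c.iso3: c for c in countries}

@dataclass
class CountryLink:
    person_id: int
    country: Country

    @classmethod
    def from_iso2(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso2[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO2 country code {country_code!r}') from None

    @classmethod
    def from_iso3(cls, person_id: int, country_code: str):
        try:
            return cls(person_id, countries_by_iso3[country_code])
        except KeyError:
            raise ValueError(f'invalid ISO3 country code {country_code!r}') from None

link = CountryLink.from_iso3(456, 'AFG')
print(link.country.name)  # Afghanistan

# Hypothetical batch of (person_id, ISO2 code) pairs, one of them invalid
rows = [(123, 'AW'), (999, 'ZZ'), (789, 'AO')]
links, rejected = [], []
for person_id, code in rows:
    try:
        links.append(CountryLink.from_iso2(person_id, code))
    except ValueError as exc:
        print(exc)  # invalid ISO2 country code 'ZZ'
        rejected.append((person_id, code))
```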
Jasmijn
  • This method would mean that I have to use a try/except block every time I create my class. Wouldn't that have performance implications? – EdG Aug 08 '21 at 19:19
  • Try/except blocks have a minimal performance impact if no exception is raised. It might even be _faster_ because there is no `in` check and if-branch. Either way, I don't think this is likely to be the bottleneck in your code. – Jasmijn Aug 08 '21 at 19:30
  • One could argue that forcing the check for invalid data is a _good_ thing. – Matthew Purdon Apr 28 '22 at 19:31
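Rather than settle the performance question by assertion, you could time both styles on your own interpreter. This is a rough `timeit` sketch with a trimmed lookup table and hypothetical function names; the numbers depend on Python version and machine, and the miss case is where try/except usually loses, because raising and catching an exception costs more than a failed `in` test.

```python
import timeit

countries_by_iso2 = {'AW': 'Aruba', 'AF': 'Afghanistan', 'AO': 'Angola'}

def lookup_lbyl(code):
    # "Look before you leap": membership test, then a second lookup
    if code not in countries_by_iso2:
        return None
    return countries_by_iso2[code]

def lookup_eafp(code):
    # "Easier to ask forgiveness": one lookup, catch the miss
    try:
        return countries_by_iso2[code]
    except KeyError:
        return None

for func in (lookup_lbyl, lookup_eafp):
    for code in ('AW', 'ZZ'):  # a hit and a miss
        seconds = timeit.timeit(lambda: func(code), number=100_000)
        print(f'{func.__name__}({code!r}): {seconds:.4f}s')
```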