
I have this code:

from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)

How can I create additional attributes in the dataclass Section so that each one is a list of the corresponding attribute of the nested Position objects?

In my example, I would like the section object to also provide:

sec.name = ['a', 'b', 'c']   #[pos1.name,pos2.name,pos3.name]
sec.lon = [52, 46, 45]       #[pos1.lon,pos2.lon,pos3.lon]
sec.lat = [10, -10, -10]     #[pos1.lat,pos2.lat,pos3.lat]

I tried to define the dataclass as:

@dataclass
class Section:
    positions: List[Position]
    names :  List[Position.name]

But it does not work, because name is not a class attribute of Position. I can set the attribute later in the code (e.g. with sec.name = [x.name for x in sec.positions]), but it would be nicer if it could be done at the dataclass definition level.
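The failing attempt can be illustrated directly: a dataclass field declared without a default value is only an annotation, so Position has no class attribute name, while the field names remain available through dataclasses.fields(). A short sketch, which also shows the manual workaround (a one-off snapshot that does not follow later changes to the positions):

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]


# A field without a default is not stored as a class attribute,
# so Position.name (as used in List[Position.name]) raises AttributeError.
print(hasattr(Position, "name"))  # False

# The field names can be read with dataclasses.fields() instead.
print([f.name for f in fields(Position)])  # ['name', 'lon', 'lat']

# Manual workaround: attach the list to the instance after creation.
# This is a snapshot; it will not follow later changes to the positions.
sec = Section([Position('a', 52, 10), Position('b', 46, -10), Position('c', 45, -10)])
sec.name = [x.name for x in sec.positions]
print(sec.name)  # ['a', 'b', 'c']
```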

After posting this question, I found the beginning of an answer (https://stackoverflow.com/a/65222586/13890678).

But I was wondering whether there is a more generic/"automatic" way of defining the Section methods .names(), .lons(), .lats(), ...? Ideally the developer would not have to define each method individually; instead, these methods would be created from the attributes of the Position objects.

lhoupert
  • Interestingly, this is similar to how `pandas.DataFrame` is implemented: each column is a `pandas.Series` object and can be accessed as an attribute (correct me if anyone knows otherwise). – Leonardus Chen Dec 17 '20 at 07:40
  • You can mess around with the dataclass creation to make something convenient happen. I'll try to come up with an answer, but it's going to be full of meta-programming. – Arne Dec 18 '20 at 23:38

3 Answers


You could create a new field after __init__ was called:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]
    _pos: dict = field(init=False, repr=False)

    def __post_init__(self):
        # create _pos after init is done, read only!
        Section._pos = property(Section._get_positions)

    def _get_positions(self):
        # build {field name: [value of that field for each position]}
        _pos = {}
        for name in [f.name for f in fields(self.positions[0])]:
            _pos[name] = [getattr(p, name) for p in self.positions]
        return _pos


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec._pos['name'])
print(sec._pos['lon'])
print(sec._pos['lat'])

Out:

[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
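Since _pos is a property, the dictionary is rebuilt every time it is accessed, so it already reflects later changes to the Position objects. A minimal sketch of the same idea with the property declared directly in the class body, which avoids both the __post_init__ step and the extra _pos field; the property name by_field is a hypothetical choice, and an empty positions list is assumed to yield an empty dict:

```python
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    @property
    def by_field(self) -> Dict[str, list]:
        # Recomputed on every access: {field name: [value of that field for each position]}
        if not self.positions:
            return {}
        names = [f.name for f in fields(self.positions[0])]
        return {n: [getattr(p, n) for p in self.positions] for n in names}


pos1 = Position('a', 52, 10)
sec = Section([pos1, Position('b', 46, -10), Position('c', 45, -10)])
print(sec.by_field['lon'])  # [52, 46, 45]
pos1.lon = 50
print(sec.by_field['lon'])  # [50, 46, 45]
```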

Edit:

If you need something more generic, you could override __getattr__:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

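    def __getattr__(self, keyName):
        # only called when normal attribute lookup fails; read positions from
        # __dict__ so a missing "positions" attribute can't recurse forever
        positions = self.__dict__.get("positions") or []
        if positions:
            for f in fields(positions[0]):
                if f"{f.name}s" == keyName:
                    return [getattr(x, f.name) for x in positions]
        # no matching field: raise, so hasattr() and similar checks behave normally
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {keyName!r}")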

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.names)
print(sec.lons)
print(sec.lats)

Out:

['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
Maurice Meyer
  • Thank you very much for the quick reply! It is indeed a good solution. The only problem I see is that if one of the pos* objects is changed after the section object is created, the section object will be out of sync. Is there a way to prevent that? – lhoupert Dec 09 '20 at 18:10

After some more thinking, I came up with an alternative solution using methods:


from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

    def names(self):
        return [x.name for x in self.positions]

    def lons(self):
        return [x.lon for x in self.positions]

    def lats(self):
        return [x.lat for x in self.positions]


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec.names())
print(sec.lons())
print(sec.lats())

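A more generic variant of these methods can be generated from the fields of Position instead of being written by hand. A minimal sketch, assuming read-only properties (accessed without parentheses) named after each field plus an "s" are acceptable; the helper _make_getter is a hypothetical name:

```python
from dataclasses import dataclass, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]


def _make_getter(field_name):
    # Closure over field_name, so each property reads its own field
    def getter(self):
        return [getattr(p, field_name) for p in self.positions]
    return getter


# Add one read-only property per Position field: names, lons, lats, ...
for f in fields(Position):
    setattr(Section, f"{f.name}s", property(_make_getter(f.name)))


pos1 = Position('a', 52, 10)
sec = Section([pos1, Position('b', 46, -10), Position('c', 45, -10)])
print(sec.names)  # ['a', 'b', 'c']
print(sec.lons)   # [52, 46, 45]
```

The factory function is used so that each property captures its own field name; a lambda defined directly in the loop would see only the last value of f when it is eventually called.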

lhoupert

The way I understood you, you'd like to declare dataclasses that are flat data containers (like Position), which are nested into a container of another dataclass (like Section). The outer dataclass should then be able to access a list of all the attributes of its inner dataclass(es) through simple name access.

We can implement this kind of functionality (calling it, for example, introspect) on top of how a regular dataclass works, and can enable it on demand, similar to the already existing flags:

from dataclasses import is_dataclass, fields, dataclass as dc

# Python 3.7-3.9 dataclass signature, plus an "introspect" keyword
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              unsafe_hash=False, frozen=False, introspect=False):

    def wrap(cls):
        # run original dataclass decorator
        dc(cls, init=init, repr=repr, eq=eq, order=order,
           unsafe_hash=unsafe_hash, frozen=frozen)

        # add our custom "introspect" logic on top
        if introspect:
            for field in fields(cls):
                # only consider nested dataclass in containers
                try:
                    name = field.type._name
                except AttributeError:
                    continue
                if name not in ("List", "Set", "Tuple"):
                    continue
                contained_dc = field.type.__args__[0]
                if not is_dataclass(contained_dc):
                    continue
                # once we got them, add their fields as properties
                for dc_field in fields(contained_dc):
                    # if there are name-clashes, use f"{field.name}_{dc_field.name}" instead
                    property_name = dc_field.name
                    # bind loop variables as default arguments to avoid late-binding bugs
                    def magic_property(self, field=field, dc_field=dc_field):
                        return [getattr(attr, dc_field.name) for attr in getattr(self, field.name)]
                    # here is where the magic happens
                    setattr(
                        cls,
                        property_name,
                        property(magic_property)
                    )
        return cls

    # Handle being called with or without parens
    if _cls is None:
        return wrap
    return wrap(_cls)

The resulting dataclass-function can now be used in the following way:

from typing import List

# regular dataclass
@dataclass
class Position:
    name: str
    lon: float
    lat: float
    
# this one will introspect its fields and try to add magic properties
@dataclass(introspect=True)
class Section:
    positions: List[Position]

And that's it. The properties get added during class construction, and will even update accordingly if any of the objects changes during its lifetime:

>>> p_1 = Position("1", 1.0, 1.1)
>>> p_2 = Position("2", 2.0, 2.1)
>>> p_3 = Position("3", 3.0, 3.1)
>>> section = Section([p_1 , p_2, p_3])
>>> section.name
['1', '2', '3']
>>> section.lon
[1.0, 2.0, 3.0]
>>> p_1.lon = 5.0
>>> section.lon
[5.0, 2.0, 3.0]
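The decorator detects list/set/tuple annotations through field.type._name, which is a private attribute of the typing aliases. A sketch of the same check using the public typing.get_origin() and typing.get_args() helpers (available since Python 3.8), which also accept builtin generics such as list[Position] on Python 3.9+; the helper name contained_dataclass is hypothetical:

```python
import typing
from dataclasses import dataclass, is_dataclass
from typing import List, Optional, Tuple


def contained_dataclass(tp) -> Optional[type]:
    # Return the dataclass inside List[...], Set[...] or Tuple[...], else None.
    # get_origin() maps typing.List[X] to list, typing.Tuple[X, ...] to tuple, etc.
    if typing.get_origin(tp) not in (list, set, tuple):
        return None
    args = typing.get_args(tp)
    if args and is_dataclass(args[0]):
        return args[0]
    return None


@dataclass
class Position:
    name: str
    lon: float
    lat: float


print(contained_dataclass(List[Position]))
print(contained_dataclass(Tuple[Position, ...]))
print(contained_dataclass(List[int]))  # None
print(contained_dataclass(str))        # None
```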
Arne