I have a Python (3.8) abstract base class with two classes inheriting from it:

from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)

class BaseDataStore(ABC, Generic[BoundedModel]):
    def __init__(self, resource_name: str) -> None:
        self.client = client(resource_name)

    @abstractmethod
    def get_all(self) -> List[BoundedModel]:
        pass

class MetadataStore(BaseDataStore[Metadata]):
    def get_all(self) -> List[Metadata]:
        items = self.client.get_all()
        return [Metadata(**item) for item in items]
    
class TranscriptStore(BaseDataStore[Transcript]):
    def get_all(self) -> List[Transcript]:
        items = self.client.get_all()
        return [Transcript(**item) for item in items]

The bound of BoundedModel, CustomBaseModel, is a pydantic model class, so Metadata and Transcript are pydantic models used for validation.

Both concrete implementations of get_all do exactly the same thing: they validate each item with the store's Pydantic model. This works, but it forces me to write out a concrete implementation for every BaseDataStore subclass.

Is there any way that I could implement get_all as a generic method (rather than abstract) in the parent BaseDataStore, therefore removing the need for concrete implementations in the children?

alexcs
  • What does the `__init__` method of these classes look like? – Anentropic Mar 02 '23 at 10:13
  • It doesn't do anything complex, simply inits the client. Why do you ask? Thanks for contributing a solution - I'm still evaluating the different options. – alexcs Mar 03 '23 at 09:27

2 Answers

You can avoid re-implementing the method in each subclass by storing the item type in a class variable, whose value can be derived directly from the Generic type parameter.

Like this:

from abc import ABC
from typing import Generic, List, Type, TypeVar


class CustomBaseModel:
    pass

class Metadata(CustomBaseModel):
    pass

class Transcript(CustomBaseModel):
    pass


class Client:
    def get_all(self) -> List[dict]:
        return [{}]


BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)


class BaseDataStore(ABC, Generic[BoundedModel]):
    _item_cls: Type[BoundedModel]

    client = Client()

    def get_all(self) -> List[BoundedModel]:
        items = self.client.get_all()
        return [self._item_cls(**item) for item in items]

class MetadataStore(BaseDataStore[Metadata]):
    pass

class TranscriptStore(BaseDataStore[Transcript]):
    pass


metadata_items = MetadataStore().get_all()
# metadata_items: list[Metadata]

This type-checks:
https://mypy-play.net/?mypy=latest&python=3.11&gist=4f50432739f25ec6ca444e787c8ee0eb

...but unfortunately it doesn't actually work at runtime yet: _item_cls is only annotated, never assigned a value, so get_all() fails with an AttributeError.
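To see why, here is a stripped-down sketch (Metadata is a plain stand-in class rather than a pydantic model, and the store has no client). A bare annotation such as `_item_cls: Type[BoundedModel]` only adds an entry to the class's `__annotations__`; the attribute itself does not exist until something assigns it:

```python
from typing import Generic, List, Type, TypeVar


class Metadata:
    pass


BoundedModel = TypeVar("BoundedModel")


class BaseDataStore(Generic[BoundedModel]):
    # Annotation only: recorded in __annotations__, no value is ever assigned
    _item_cls: Type[BoundedModel]

    def get_all(self) -> List[BoundedModel]:
        return [self._item_cls()]


class MetadataStore(BaseDataStore[Metadata]):
    pass


print("_item_cls" in BaseDataStore.__annotations__)  # True
print(hasattr(MetadataStore, "_item_cls"))          # False

try:
    MetadataStore().get_all()
except AttributeError as exc:
    # The lookup of self._item_cls fails because nothing assigned it
    print("AttributeError:", exc)
```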

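We can work around that with a bit of metaprogramming: a metaclass that reads the type argument from each subclass's subscripted base and assigns it to _item_cls when the subclass is created.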

from abc import ABCMeta
from typing import Generic, List, Type, TypeVar, get_args, get_origin


class CustomBaseModel:
    pass

class Metadata(CustomBaseModel):
    pass

class Transcript(CustomBaseModel):
    pass


class Client:
    def get_all(self) -> List[dict]:
        return [{}]


BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)


class GenericDataStoreMetaclass(ABCMeta):
    def __new__(cls, name, bases, dct):
        cls_ = super().__new__(cls, name, bases, dct)
        # __orig_bases__ is only in the class namespace when the class statement
        # used subscripted bases, e.g. BaseDataStore[Metadata]
        for og_base in dct.get("__orig_bases__", ()):
            # Compare via the metaclass rather than the name BaseDataStore,
            # which does not exist yet while BaseDataStore itself is being built
            if isinstance(get_origin(og_base), GenericDataStoreMetaclass):
                # introspect the type param of the Generic alias
                type_arg = get_args(og_base)[0]
                if not isinstance(type_arg, TypeVar):
                    cls_._item_cls = type_arg
        return cls_

class BaseDataStore(Generic[BoundedModel], metaclass=GenericDataStoreMetaclass):
    _item_cls: Type[BoundedModel]

    client = Client()

    def get_all(self) -> List[BoundedModel]:
        items = self.client.get_all()
        return [self._item_cls(**item) for item in items]

class MetadataStore(BaseDataStore[Metadata]):
    pass

class TranscriptStore(BaseDataStore[Transcript]):
    pass


metadata_items = MetadataStore().get_all()
# [<__main__.Metadata at 0x108493520>]

This version now works at runtime.

Anentropic

Yes, you can.

Use the trick of reading __orig_bases__ to access the type argument that a specific subclass passes to BaseDataStore. Then a single concrete implementation on BaseDataStore is enough, and you never need to repeat the type argument anywhere in the subclasses.
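For reference, here is roughly what that trick inspects, using small stand-in classes (Base, Item, ItemBase and ItemChild are illustrative names only). When a class statement lists a subscripted generic base, `__bases__` holds the plain class while `__orig_bases__` keeps the subscripted alias, and `typing.get_origin` / `typing.get_args` (both added in Python 3.8) take that alias apart:

```python
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class Base(Generic[T]):
    pass


class Item:
    pass


class ItemBase(Base[Item]):
    pass


class ItemChild(ItemBase):
    pass


# __bases__ contains the plain (unsubscripted) class used for the MRO
print(ItemBase.__bases__ == (Base,))  # True

# __orig_bases__ keeps the subscripted alias exactly as it was written
alias = ItemBase.__orig_bases__[0]
print(get_origin(alias) is Base)      # True
print(get_args(alias) == (Item,))     # True

# ItemChild has no subscripted base of its own, so attribute lookup
# falls back to the __orig_bases__ inherited from ItemBase
print("__orig_bases__" in vars(ItemChild))                # False
print(ItemChild.__orig_bases__ is ItemBase.__orig_bases__)  # True
```

Because `__orig_bases__` is looked up like any other class attribute, a subclass that does not subscript its own base sees its parent's value; with the `__init_subclass__` approach below, that simply re-records the same type argument.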

Suppose you have the following models:

from pydantic import BaseModel


class CustomBaseModel(BaseModel):
    pass


class Foo(CustomBaseModel):
    x: int


class Bar(CustomBaseModel):
    y: str

Here is the solution I propose:

from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from typing import get_args, get_origin

BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)

class BaseDataStore(Generic[BoundedModel]):
    _type_arg: Optional[Type[BoundedModel]] = None

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Saves the type argument in the `_type_arg` class attribute."""
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, BaseDataStore):
                continue
            type_arg = get_args(base)[0]
            # Do not set the attribute for GENERIC subclasses!
            if not isinstance(type_arg, TypeVar):
                cls._type_arg = type_arg
                return

    @classmethod
    def get_model(cls) -> Type[BoundedModel]:
        if cls._type_arg is None:
            raise AttributeError(f"{cls.__name__} is generic; type argument unspecified")
        return cls._type_arg

    def get_all(self) -> List[BoundedModel]:
        items = self.demo_data  # just for this example
        return [self.get_model()(**item) for item in items]

    demo_data: List[Dict[str, Any]]  # just for this example

Usage:

class FooStore(BaseDataStore[Foo]):
    demo_data = [{"x": 1}, {"x": -1}]


class BarStore(BaseDataStore[Bar]):
    demo_data = [{"y": "spam"}, {"y": "eggs"}]


foos = FooStore().get_all()
bars = BarStore().get_all()

print(foos)
print(bars)

Output:

[Foo(x=1), Foo(x=-1)]
[Bar(y='spam'), Bar(y='eggs')]

Passes mypy --strict. No metaclass magic required.
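The `isinstance(type_arg, TypeVar)` check is what allows an intermediate layer that is itself still generic. Below is a trimmed-down sketch of the same `__init_subclass__` logic with plain classes instead of pydantic models; `CachingStore` is a hypothetical intermediate store, not something from the question:

```python
from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin


class CustomBaseModel:
    pass


class Foo(CustomBaseModel):
    pass


BoundedModel = TypeVar("BoundedModel", bound=CustomBaseModel)


class BaseDataStore(Generic[BoundedModel]):
    _type_arg: Optional[Type[BoundedModel]] = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, BaseDataStore):
                continue
            type_arg = get_args(base)[0]
            # A TypeVar here means the subclass is itself still generic
            if not isinstance(type_arg, TypeVar):
                cls._type_arg = type_arg
                return


# Hypothetical intermediate layer that stays generic in the model type
class CachingStore(BaseDataStore[BoundedModel]):
    pass


# Concrete store: the type argument is given to the intermediate class
class FooStore(CachingStore[Foo]):
    pass


print(CachingStore._type_arg)     # None
print(FooStore._type_arg is Foo)  # True
```

CachingStore keeps `_type_arg` as `None` because its base is parameterised by a TypeVar, while FooStore records Foo from `CachingStore[Foo]`. The explicit `@classmethod` decorator is omitted here because `__init_subclass__` is implicitly a class method.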

Daniil Fajnberg