
I am using Python 3.11 and I need to detect whether an optional class attribute's type is an Enum (i.e. a subclass of Enum).

With typing.get_type_hints() I can get the type hints as a dict, but how do I check whether a field's type is an optional Enum (subclass)? Even better would be a way to get the underlying type of any optional field, whether it is Optional[str], Optional[int], Optional[Class_X], etc.

Example code

from typing import Optional, get_type_hints
from enum import IntEnum, Enum

class TestEnum(IntEnum):
    foo = 1
    bar = 2


class Foo:
    opt_enum : TestEnum | None = None

types = get_type_hints(Foo)['opt_enum']

This works

(ipython)

In [4]: Optional[TestEnum] == types
Out[4]: True

These fail

(yes, these are desperate attempts)

In [6]: Optional[IntEnum] == types
Out[6]: False

and

In [11]: issubclass(Enum, types)
Out[11]: False

and

In [12]: issubclass(types, Enum)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [12], line 1
----> 1 issubclass(types, Enum)

TypeError: issubclass() arg 1 must be a class

and

In [13]: issubclass(types, Optional[Enum])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [13], line 1
----> 1 issubclass(types, Optional[Enum])

File /usr/lib/python3.10/typing.py:1264, in _UnionGenericAlias.__subclasscheck__(self, cls)
   1262 def __subclasscheck__(self, cls):
   1263     for arg in self.__args__:
-> 1264         if issubclass(cls, arg):
   1265             return True

TypeError: issubclass() arg 1 must be a class

and

In [7]: IntEnum in types
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In [7], line 1
----> 1 IntEnum in types

TypeError: argument of type 'types.UnionType' is not iterable

Why I needed this

I have several cases where I import data from CSV files and create an object of a class from each row. csv.DictReader() returns a dict[str, str], so I need to convert the field values to the correct types before creating the object. However, some of the object fields are Optional[int], Optional[bool], Optional[EnumX] or Optional[ClassX]. Several of those classes use multiple inheritance to inherit my CSVImportable() class/interface. I want to implement the logic once in the CSVImportable() class instead of writing roughly the same field-specific code in every subclass. This CSVImportable._field_type_updater() should:

  1. correctly convert the values, at least for basic types and enums
  2. gracefully skip Optional[ClassX] fields

Naturally I am thankful for better designs too :-)

Jylpah

1 Answer


When you are dealing with a parameterized type (a generic, or a special form like typing.Optional), you can inspect it via get_args/get_origin.

Doing so, you'll see that T | S is implemented slightly differently from typing.Union[T, S]. The origin of the former is types.UnionType, while that of the latter is typing.Union. Unfortunately, this means we need two distinct checks to cover both variants.

from types import UnionType
from typing import Union, get_origin

def is_union(t: object) -> bool:
    origin = get_origin(t)
    return origin is Union or origin is UnionType

typing.Optional[X] is just shorthand for typing.Union[X, None], so its origin is also typing.Union. Here is a working demo:

from enum import IntEnum
from types import UnionType
from typing import Optional, get_type_hints, get_args, get_origin, Union


class TestEnum(IntEnum):
    foo = 1
    bar = 2


class Foo:
    opt_enum1: TestEnum | None = None
    opt_enum2: Optional[TestEnum] = None
    opt_enum3: TestEnum
    opt4: str


def is_union(t: object) -> bool:
    origin = get_origin(t)
    return origin is Union or origin is UnionType


if __name__ == "__main__":
    for name, type_ in get_type_hints(Foo).items():
        if type_ is TestEnum or (is_union(type_) and TestEnum in get_args(type_)):
            print(name, "accepts TestEnum")

Output:

opt_enum1 accepts TestEnum
opt_enum2 accepts TestEnum
opt_enum3 accepts TestEnum
Daniil Fajnberg
  • Thanks, this got me further with handling UnionType, but it still requires me to name the specific Enum subclass instead of checking whether the type is Enum or any subclass of it. It seems I would need to iterate over the members of `get_args()` and test `issubclass()`. – Jylpah Nov 23 '22 at 18:39