
I would like to wrap an abstract method in every subclass of an ABC. I tried doing that by implementing __init_subclass__ as below:

import abc


class Base(abc.ABC):

  @abc.abstractmethod
  def foo(self) -> str:
    pass

  def __init_subclass__(cls):
    super().__init_subclass__()
    orig_foo = cls.foo
    cls.foo = lambda s: orig_foo(s) + 'def'


class Derived(Base):

  def foo(self):
    return 'abc'

This works, and if I do something like:

derived = Derived()
derived.foo()  # -> 'abcdef'

which is expected. Unfortunately, I noticed that this approach does not invoke the abc check, so if I forget to implement foo on Derived:

class Derived(Base):
  pass

I can still create it:

derived = Derived()  # This works
derived.foo()        # -> TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'
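As far as I can tell from how ABCMeta works, this is why the check is skipped. ABCMeta computes `__abstractmethods__` only after `type.__new__` has returned, and `type.__new__` is what calls `__init_subclass__`. By that point `Derived.foo` is already the lambda, which lacks the `__isabstractmethod__` flag that `@abc.abstractmethod` sets, so `Derived` ends up with no abstract methods at all. The small check below reuses the question's `Base` together with a `Derived` that omits `foo`:

```python
import abc


class Base(abc.ABC):

  @abc.abstractmethod
  def foo(self) -> str:
    pass

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    orig_foo = cls.foo
    # Replaces foo unconditionally, even when foo is still the abstract Base.foo.
    cls.foo = lambda s: orig_foo(s) + 'def'


class Derived(Base):
  pass  # foo deliberately not implemented


# Base.foo carries the marker that ABCMeta looks for...
print(getattr(Base.foo, '__isabstractmethod__', False))     # True
# ...but the lambda installed on Derived does not.
print(getattr(Derived.foo, '__isabstractmethod__', False))  # False
# So ABCMeta records no abstract methods for Derived, and Derived() succeeds.
print(Derived.__abstractmethods__)                          # frozenset()
```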

Is there a way to do the above wrapping, but not break abc.abstractmethod checks?

tepsijash
  • A None check on `orig_foo`? – Kris Apr 11 '22 at 13:09
  • Seems like orig_foo is not None, but evaluates to `Base.foo` if we forget to implement it in `Derived`. Just posted an answer that addresses it. Thanks for the lead! – tepsijash Apr 11 '22 at 16:14

2 Answers


I think a better technique would be to declare a separate non-abstract method that makes use of the abstract method.

import abc

class Base(abc.ABC):

  @abc.abstractmethod
  def foo_body(self) -> str:
    pass

  def foo(self) -> str:
    return self.foo_body() + 'def'


class Derived(Base):

  def foo_body(self):
    return 'abc'

Concrete subclasses are only responsible for overriding foo_body; Base.foo itself need not be touched.
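One way to check that this keeps the ABC machinery intact is to define a subclass that forgets `foo_body` and try to instantiate it. The sketch below repeats the answer's classes and adds a hypothetical `Incomplete` subclass for that purpose; the exact wording of the `TypeError` message varies between Python versions, so it is only printed, not asserted.

```python
import abc


class Base(abc.ABC):

  @abc.abstractmethod
  def foo_body(self) -> str:
    pass

  def foo(self) -> str:
    # Concrete "template" method: appends the suffix to whatever the hook returns.
    return self.foo_body() + 'def'


class Derived(Base):

  def foo_body(self):
    return 'abc'


class Incomplete(Base):  # hypothetical subclass that forgot foo_body
  pass


print(Derived().foo())  # abcdef

try:
  Incomplete()
except TypeError as exc:
  # foo_body is still abstract on Incomplete, so instantiation is refused.
  print('refused:', exc)
```

Because `Base.foo` is never reassigned, the `@abc.abstractmethod` flag on `foo_body` survives untouched, and the usual check does the work.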

chepner
  • Definitely agree that this is cleaner. For whatever reason, I am trying to do this wrapping to improve the efficiency of an existing library, so this would be a breaking change. – tepsijash Apr 12 '22 at 12:57

I was able to solve it by overriding cls.foo in __init_subclass__ only if cls.foo differs from Base.foo (i.e. is already overridden).

This yields the following Base implementation:

import abc


class Base(abc.ABC):

  @abc.abstractmethod
  def foo(self) -> str:
    pass

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    orig_foo = cls.foo
    if orig_foo is not Base.foo:  # Already overridden.
      cls.foo = lambda s: orig_foo(s) + 'def'
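One caveat, if I am reading this correctly: a further subclass that does not override foo inherits the already-wrapped lambda, which also differs from Base.foo, so it gets wrapped a second time and returns 'abcdefdef'. A minimal sketch of one variant, assuming you only want to wrap a foo defined directly in each subclass's own body, checks `cls.__dict__` instead of comparing against Base.foo. `Child` and `Forgot` are hypothetical names used only for the demonstration.

```python
import abc


class Base(abc.ABC):

  @abc.abstractmethod
  def foo(self) -> str:
    pass

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # Only wrap a foo defined in this class's own body. An inherited foo
    # (still abstract, or already wrapped by a parent) is left alone, so the
    # ABC check still sees Base.foo and nothing is wrapped twice.
    if 'foo' in cls.__dict__:
      orig_foo = cls.__dict__['foo']
      cls.foo = lambda s: orig_foo(s) + 'def'


class Derived(Base):

  def foo(self):
    return 'abc'


class Child(Derived):  # hypothetical grandchild that does not override foo
  pass


class Forgot(Base):  # hypothetical subclass that never implements foo
  pass


print(Derived().foo())  # abcdef
print(Child().foo())    # abcdef (wrapped once, not twice)

try:
  Forgot()
except TypeError as exc:
  print('refused:', exc)
```

This still assumes foo is a plain function; a staticmethod or classmethod stored under that name would need different handling.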
tepsijash