EDIT #1: I have updated the example with a "manual" if/else based solution as suggested to demonstrate the need for further automation.


How can I efficiently dispatch functions (i.e. implement something like multimethods) where the target function is selected based on keyword argument names instead of types?

My use case is implementing multiple factory methods for dataclasses whose fields are inter-dependent and which can be initialized from different subsets of these fields.

Below is an example using recursion that works, but it requires a lot of hand-written, error-prone code and doesn't really scale to more complex cases.

from dataclasses import dataclass

def a_from_b_c(b, c):
    return b+c

def b_from_a_c(a, c):
    return a+c

def c_from_a_b(a, b):
    return a**b

@dataclass
class foo(object):
    a: float
    b: float
    c: float

    @classmethod
    def init_from(cls, **kwargs):
        # Each branch derives one missing field, then recurses with the extended kwargs
        if "a" not in kwargs and all(k in kwargs for k in ("b", "c")):
            kwargs["a"] = a_from_b_c(kwargs["b"], kwargs["c"])
            return cls.init_from(**kwargs)
        if "b" not in kwargs and all(k in kwargs for k in ("a", "c")):
            kwargs["b"] = b_from_a_c(kwargs["a"], kwargs["c"])
            return cls.init_from(**kwargs)
        if "c" not in kwargs and all(k in kwargs for k in ("a", "b")):
            kwargs["c"] = c_from_a_b(kwargs["a"], kwargs["b"])
            return cls.init_from(**kwargs)
        return cls(**kwargs)
        

I am looking for a solution that scales to dataclasses with many fields and complex initialization paths, while requiring less hand-written code with its duplication and potential for errors. The patterns in the code above are quite obvious and could be automated, but I want to be sure I use the right tools here.

simfinite

2 Answers


After the recent edit, which is frankly quite a big change, you could do something like this:

import inspect
from dataclasses import dataclass
from collections import defaultdict


class Initializer:
    """This class collects all registered functions and allows
    multiple ways to calculate your field.
    """
    def __init__(self):
        self.mappings = defaultdict(list)
        
    def __call__(self, arg):
        def wrapper(func):
            self.mappings[arg].append(func)
            # Return the function so the decorated name is not rebound to None
            return func
        return wrapper


# Create an instance and register your functions
init = Initializer()


# Add the `kwargs` for convenience
@init("a")
def a_from_b_c(b, c, **kwargs):
    return b + c


@init("a")
def a_from_b_d(b, d, **kwargs):
    return b + d


@init("b")
def b_from_a_c(a, c, **kwargs):
    return a + c


@init("c")
def c_from_a_b(a, b, **kwargs):
    return a ** b


@init("d")
def d_from_a_b_c(a, b, c, **kwargs):
    return a ** b + c


@dataclass
class foo(object):
    a: float
    b: float
    c: float
    d: float

    @classmethod
    def init_from(cls, **kwargs):
        # dataclasses.fields(cls) is the public accessor for the same information
        for field in cls.__dataclass_fields__:
            if field not in kwargs:
                funcs = init.mappings[field]

                # Multiple functions means a loop. If you're sure
                # you have a 1-to-1 mapping then change the defaultdict
                # to a dict[field->function]
                for func in funcs:
                    func_args = inspect.getfullargspec(func).args

                    if all(arg in kwargs for arg in func_args):
                        kwargs[field] = func(**kwargs)
                        # Recurse so any further missing fields are filled too
                        return cls.init_from(**kwargs)
        # Nothing left to derive; the constructor raises TypeError
        # if some fields are still missing
        return cls(**kwargs)

Then use it:

>>> foo.init_from(a=3, b=2, d=3)
foo(a=3, b=2, c=9, d=3)

>>> foo.init_from(a=3, b=2, c=3)
foo(a=3, b=2, c=3, d=12)
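
When several fields are missing and some of them depend on others that must be derived first, one possible variation is to keep applying rules until a full pass derives nothing new. The following is only a minimal sketch under stated assumptions: `RULES`, `resolve` and the class name `Foo` are hypothetical names introduced for illustration, and the rule table simply mirrors some of the functions registered above, written as an explicit `field -> [(function, required fields)]` mapping.

```python
from dataclasses import dataclass, fields

# Hypothetical rule table: field -> list of (function, required field names).
RULES = {
    "a": [(lambda b, c: b + c, ("b", "c"))],
    "b": [(lambda a, c: a + c, ("a", "c"))],
    "c": [(lambda a, b: a ** b, ("a", "b"))],
    "d": [(lambda a, b, c: a ** b + c, ("a", "b", "c"))],
}


def resolve(known, rules, names):
    """Derive missing fields until a full pass adds nothing new."""
    values = dict(known)
    progress = True
    while progress:
        progress = False
        for name in names:
            if name in values:
                continue
            # Use the first rule whose required fields are all available
            for func, params in rules.get(name, ()):
                if all(p in values for p in params):
                    values[name] = func(*(values[p] for p in params))
                    progress = True
                    break
    return values


@dataclass
class Foo:
    a: float
    b: float
    c: float
    d: float

    @classmethod
    def init_from(cls, **kwargs):
        names = [f.name for f in fields(cls)]
        values = resolve(kwargs, RULES, names)
        missing = [n for n in names if n not in values]
        if missing:
            # Report which fields could not be derived from the given inputs
            raise ValueError(f"cannot derive fields: {missing}")
        return cls(**values)
```

With this sketch, `Foo.init_from(a=3, b=2)` first derives `c` and then `d` from the freshly derived `c`, while `Foo.init_from(a=3)` raises a `ValueError` naming the fields that could not be derived.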
Kostas Mouratidis
  • Sorry for the major edit. The initial example was just too simple, although I did say that I wasn't going for such a trivial example but for a scalable solution. Anyway, thank you for your suggestion, it looks quite promising and I'll try it out asap – simfinite Jan 17 '21 at 09:18
  • Don't sweat it. I was a bit salty before the morning coffee. Another way I was thinking of dealing with this was to use some fancy `eval` or a "mapping" dictionary (e.g. `field -> tuple(func, params)`), but it seemed like extra work. Anyhow, let me know if you want me to add more details in the code or explanations about how it works – Kostas Mouratidis Jan 17 '21 at 10:07

Here is a solution based on @kostas-mouratidis's idea of storing a mapping from fields to the methods used to initialize them. By using a class decorator, the mapping can be stored with the class (where it belongs, imho). By using another decorator for the methods that initialize fields, the resulting code looks quite clean and readable.

Any suggestions for improvements?

from dataclasses import dataclass
import inspect 

def dataclass_greedy_init(cls):
    """Dataclass decorator that adds an 'init_from' class method to recursively initialize 
    all fields and fully initialize an instance of the class from a given subset of
    fields specified as keyword arguments.

    In order to achieve this, the class is searched for *field init methods*, i.e. static 
    methods decorated with the 'init_field' decorator. A mapping from field names to these 
    methods is built and stored as an attribute of the class. The 'init_from' method looks
    up appropriate methods given the set of fields specified as keyword arguments in the 
    'init_from' class method. It initializes missing fields recursively in a greedy fashion,   
    i.e. it initializes the first missing field for which a field init method can be found
    and all arguments to this field init method can be supplied.  
    """    

    # Collect all field init methods
    init_methods = inspect.getmembers(cls, lambda f: inspect.isfunction(f) and not inspect.ismethod(f) and hasattr(f, "init_field"))
    # Create a mapping from field names to signatures (i.e. required fields)
    # and field init methods.
    cls.init_mapping = {}
    for init_method_name, init_method in init_methods:
        init_field = init_method.init_field
        if init_field not in cls.init_mapping:
            cls.init_mapping[init_field] = []
        cls.init_mapping[init_field].append((inspect.signature(init_method), init_method))
    # Add classmethod 'init_from'
    def init_from(cls, **kwargs):
        for field in cls.__dataclass_fields__:
            if field not in kwargs and field in cls.init_mapping:
                for init_method_sig, init_method in cls.init_mapping[field]:
                    try:
                        mapped_kwargs = {p: kwargs[p] for p in init_method_sig.parameters if p in kwargs}
                        bound_args = init_method_sig.bind(**mapped_kwargs)
                        bound_args.apply_defaults()
                        kwargs[field] = init_method(**bound_args.arguments)
                        return cls.init_from(**kwargs)
                    except TypeError:
                        pass
        return cls(**kwargs)
    cls.init_from = classmethod(init_from)
    return cls

def init_field(field_name):
    """Decorator to be used in combination with 'dataclass_greedy_init' to generate
    static methods with an additional 'field_name' attribute that indicates for which 
    of the dataclass's fields this method should be used during initialization."""
    def inner(func):
        func.init_field = field_name
        return staticmethod(func)
    return inner

@dataclass_greedy_init
@dataclass
class foo(object):
    a: float
    b: float
    c: float
    d: float

    @init_field("a")
    def init_a_from_b_c(b,c):
        return c-b

    @init_field("b")
    def init_b_from_a_c(a,c):
        return c-a

    @init_field("c")
    def init_c_from_a_b(a,b):
        return a+b

    @init_field("c")
    def init_c_from_d(d):
        return d/2

    @init_field("d")
    def init_d_from_a_b_c(a,b,c):
        return a+b+c

    @init_field("d")
    def init_d_from_a(a):
        return 6*a

print(foo.init_from(a=1, b=2))
print(foo.init_from(a=1, c=3))
print(foo.init_from(b=2, c=3))
print(foo.init_from(a=1))
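
One detail worth noting in the decorator above: the `try` block wraps both the signature binding and the call to the field init method (and the recursive `init_from`), so a `TypeError` raised inside an init method is handled the same way as a missing argument. A minimal sketch of separating the two steps follows; `try_init`, `c_from_d` and `broken` are hypothetical names used only for illustration, and this is one possible approach rather than a drop-in replacement.

```python
import inspect


def try_init(sig, func, kwargs):
    """Return (True, value) if func can be called from kwargs, else (False, None)."""
    mapped = {p: kwargs[p] for p in sig.parameters if p in kwargs}
    try:
        # Only the binding is guarded: here a TypeError means that
        # a required argument is not available in kwargs.
        bound = sig.bind(**mapped)
    except TypeError:
        return False, None
    bound.apply_defaults()
    # Called outside the try block, so a TypeError raised by func itself propagates.
    return True, func(*bound.args, **bound.kwargs)


def c_from_d(d):
    return d / 2


def broken(a):
    # Hypothetical buggy init method: raises TypeError for a numeric argument
    return a + "x"


print(try_init(inspect.signature(c_from_d), c_from_d, {"d": 6}))
print(try_init(inspect.signature(c_from_d), c_from_d, {"a": 1}))
```

Calling `try_init(inspect.signature(broken), broken, {"a": 1})` lets the `TypeError` from `broken` surface instead of being skipped silently.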

simfinite