I have a custom class which inherits from functools.partial:
```python
from functools import partial
from typing import Callable


class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super(CustomPartial, cls).__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __call__(self, *args, **kwargs):
        # Delegate to partial.__call__ so any pre-bound args/keywords are applied
        return super().__call__(*args, **kwargs)
```
This code is fine as is for serial processing, i.e. I can create objects of this class as needed and call them as normal functions.
The issue I'm running into, though, is when I try to use one of these CustomPartial objects as the function input for joblib's Parallel processing. The exception being thrown is:

```
TypeError: CustomPartial.__new__() missing 1 required positional argument: 'func'
```

From this, I've concluded that the issue happens when the object is "un-serialized" (unpickled) in the worker processes.
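One way to see where the un-serialize step goes wrong is to ask `functools.partial` how it tells pickle to rebuild the object. The sketch below is a minimal reproduction under two assumptions that are not in the original: it uses the standard-library `pickle` instead of `dill`, and a named module-level function `add` instead of the lambda (stdlib `pickle` cannot serialize lambdas). `partial.__reduce__` returns `(type(self), (self.func,), state)`, so unpickling calls `CustomPartial(self.func)` with a single positional argument; `__new__` binds it to `func_name`, and `func` is left missing.

```python
import pickle
from functools import partial
from typing import Callable


class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super().__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self


def add(x, y):
    return x + y


custom_partial = CustomPartial("add_ten", add, y=10)

# partial.__reduce__ describes how pickle should rebuild the object:
# (callable_to_call, args_for_that_callable, state_for___setstate__)
rebuild, rebuild_args, state = custom_partial.__reduce__()
print(rebuild is CustomPartial)  # True: pickle will call CustomPartial(...)
print(rebuild_args == (add,))    # True: only the wrapped function is passed

try:
    pickle.loads(pickle.dumps(custom_partial))
except TypeError as err:
    # Same error as in the question: the one positional argument was bound
    # to func_name, so 'func' is reported as missing
    print(err)
```

The same reconstruction recipe is used by `dill`, and by the serializer joblib uses for its worker processes, which is why the error only shows up once the object crosses a process boundary.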
The code below is a minimal working example of the issue. I've tried serializing with dill and implementing the `__setstate__` / `__getstate__` methods, but nothing changes the exception being thrown.
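A likely reason `__getstate__` / `__setstate__` make no difference: pickle first calls the reconstructor returned by `__reduce__` (here `CustomPartial(func)`), and only passes the saved state to `__setstate__` after that call succeeds. `partial` also supplies its own `__reduce__`, which builds the state directly rather than calling `__getstate__`. The sketch below checks this with stdlib `pickle`; the `setstate_calls` counter and the named `add` function are hypothetical, added only for illustration.

```python
import pickle
from functools import partial
from typing import Callable


class CustomPartial(partial):
    setstate_calls = 0  # hypothetical class-level counter, for illustration only

    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super().__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __setstate__(self, state):
        # Only reached after the object has been constructed successfully
        type(self).setstate_calls += 1
        super().__setstate__(state)


def add(x, y):
    return x + y


custom_partial = CustomPartial("add_ten", add, y=10)

try:
    pickle.loads(pickle.dumps(custom_partial))
except TypeError as err:
    print("construction failed:", err)

print(CustomPartial.setstate_calls)  # 0: __setstate__ was never called
```

So the fix has to change the arguments pickle passes to the constructor, which means overriding `__reduce__` rather than the state hooks.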
```python
import dill
from typing import Callable
from functools import partial


class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super(CustomPartial, cls).__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __call__(self, *args, **kwargs):
        # Delegate to partial.__call__ so any pre-bound args/keywords are applied
        return super().__call__(*args, **kwargs)


add = lambda x, y: x + y
add_ten = partial(add, y=10)
custom_partial = CustomPartial('add_ten', add_ten)

print(dill.loads(dill.dumps(add_ten)))
# functools.partial(<function <lambda> at 0x7f7647eefa30>, y=10)

try:
    print(dill.loads(dill.dumps(custom_partial)))
except Exception as err:
    print(err)
# CustomPartial.__new__() missing 1 required positional argument: 'func'
```
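A possible fix, sketched under the assumption that only `func_name` needs to survive on top of what `partial` already stores: override `__reduce__` so the reconstructor receives both required constructor arguments, and reuse `partial`'s own state tuple `(func, args, keywords, __dict__)` so `partial.__setstate__` restores everything else. This again uses stdlib `pickle` and a named `add` function so it runs without `dill`; with the original lambda you would still need `dill` or `cloudpickle`.

```python
import pickle
from functools import partial
from typing import Callable


class CustomPartial(partial):
    def __new__(cls, func_name: str, func: Callable, *args, **kwargs):
        self = super().__new__(cls, func, *args, **kwargs)
        self.func_name = func_name
        return self

    def __call__(self, *args, **kwargs):
        # Delegate to partial.__call__ so any pre-bound args/keywords are applied
        return super().__call__(*args, **kwargs)

    def __reduce__(self):
        # The inherited partial.__reduce__ rebuilds with CustomPartial(self.func),
        # which omits func_name. Pass both constructor arguments instead, then
        # let partial.__setstate__ restore func, args, keywords and __dict__.
        return (
            type(self),
            (self.func_name, self.func),
            (self.func, self.args, self.keywords, self.__dict__),
        )


def add(x, y):
    return x + y


add_ten = partial(add, y=10)
custom_partial = CustomPartial("add_ten", add_ten)

restored = pickle.loads(pickle.dumps(custom_partial))
print(restored.func_name)  # add_ten
print(restored(5))         # 15
```

Because `dill` and `cloudpickle` honor `__reduce__` in the same way, the object should then survive the trip to joblib workers, e.g. `Parallel(n_jobs=2)(delayed(custom_partial)(x) for x in range(4))`; this last step is not verified here.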
Any help / direction towards resolving this issue would be greatly appreciated :)