
I am using sklearn's Pipeline and FunctionTransformer with a custom function

import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline

This is my code:

def f(x):
    return x*2
pipe = Pipeline([("times_2", FunctionTransformer(f))])
joblib.dump(pipe, "pipe.joblib")
del pipe
del f
pipe = joblib.load("pipe.joblib") # Causes an exception

And I get this error:

AttributeError: module '__main__' has no attribute 'f'

How can this be resolved?

Note that this issue also occurs with plain pickle.
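The root cause can be reproduced with the standard library alone. pickle (and joblib, which builds on it) stores a module-level function by reference, meaning the defining module plus the function's name, not its code, so the name must still resolve when loading. A minimal sketch, assuming it is run as a script:

```python
import pickle

def f(x):
    return x * 2

# pickle serializes a module-level function *by reference*: it records the
# defining module (f.__module__) and the function's qualified name, not its code.
payload = pickle.dumps(f)
stores_name = f.__qualname__.encode() in payload and f.__module__.encode() in payload

del f  # the name "f" can no longer be resolved in its module

load_error = None
try:
    pickle.loads(payload)
except AttributeError as exc:
    load_error = exc  # same failure mode as joblib.load in the question

print(stores_name, type(load_error).__name__)  # True AttributeError
```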

Uri Goren

1 Answer


I was able to hack together a solution using the marshal module (in addition to pickle) by overriding the magic methods __getstate__ and __setstate__, which pickle calls when saving and restoring an object.

import marshal
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin

class MyFunctionTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, func):
        # Parameter named `func` so BaseEstimator.get_params() finds self.func
        self.func = func

    def __call__(self, X):
        return self.func(X)

    def __getstate__(self):
        # Work on a copy so pickling does not strip `func` from the live object
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state

    def __setstate__(self, state):
        # Rebuild the function from its bytecode; globals resolve in this module
        state = dict(state)
        code = marshal.loads(state.pop("func_code"))
        state["func"] = FunctionType(code, globals(), state.pop("func_name"))
        self.__dict__.update(state)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return self.func(X)

Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:

import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.pipeline import Pipeline

@MyFunctionTransformer
def my_transform(x):
    return x*2
pipe = Pipeline([("times_2", my_transform)])
joblib.dump(pipe, "pipe.joblib")
del pipe
del my_transform
pipe = joblib.load("pipe.joblib")

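This works by removing the function object from the pickled state and instead marshaling its code object together with its name. On load, __setstate__ rebuilds a new function from that code, so no function with the original name has to exist in __main__.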
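The same __getstate__/__setstate__ mechanism can be tried without scikit-learn. The sketch below uses a hypothetical stdlib-only wrapper, `MarshaledFunction`, which is not part of any library; it round-trips through plain pickle after the decorated name has been deleted:

```python
import marshal
import pickle
from types import FunctionType

class MarshaledFunction:
    """Hypothetical stdlib-only wrapper using the answer's getstate/setstate idea."""

    def __init__(self, func):
        self.func = func

    def __call__(self, x):
        return self.func(x)

    def __getstate__(self):
        # Replace the function (pickled by reference) with its name and bytecode
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state

    def __setstate__(self, state):
        # Rebuild the function; its global names resolve in *this* module
        state = dict(state)
        code = marshal.loads(state.pop("func_code"))
        state["func"] = FunctionType(code, globals(), state.pop("func_name"))
        self.__dict__.update(state)

@MarshaledFunction
def times_2(x):
    return x * 2

payload = pickle.dumps(times_2)
del times_2  # nothing named times_2 is needed at load time

restored = pickle.loads(payload)
print(restored(21))  # 42
```

Only the wrapper class has to be importable when loading; the wrapped function itself travels as bytecode.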

dill also looks like a good alternative to marshaling.

Uri Goren
  • It should be: del my_transform instead of del f. Would this still work with more than one custom function or nested pipelines? – KRKirov Jan 05 '19 at 20:52
  • True, thanks, I fixed the code snippet. It would work with nested pipelines and anything that's marshallable (not every function is). – Uri Goren Jan 05 '19 at 20:57
  • You do intend to load your pipeline in a separate script, don't you? So even with your current method, won't you need to have the code of `MyFunctionTransformer` ready somewhere in memory or in your imports before calling `joblib.load`? How is that better than having the code of function `f` available through imports, maybe from another script? Am I missing something? – Vivek Kumar Jan 07 '19 at 09:37
  • Do you agree that if `FunctionTransformer` were implemented with my additions (namely `__setstate__` and `__getstate__`), pickling would include *all* the required dependencies for the pipeline? – Uri Goren Jan 07 '19 at 12:07
  • OK, so you want to change the scikit-learn API altogether. But in that case, to include all the dependencies, you would also need to take care of imports used inside the function, like `np` or `pd` or any other module the function relies on; otherwise it will throw errors. – Vivek Kumar Jan 08 '19 at 12:36
  • Marshalling all of the dependencies (e.g. `pd`, `np`, etc.) falls beyond the scope of this StackOverflow question. It is possible with third-party libs such as `dill`. If you wish to discuss the technical details of the implementation, please comment on: https://github.com/scikit-learn/scikit-learn/pull/12905 . P.S. Now that we are on the same page, I'd appreciate it if you'd reverse the downvotes. Thanks! – Uri Goren Jan 08 '19 at 18:07
  • Unfortunately, I can't reverse my vote unless the answer is edited. – Vivek Kumar Jan 09 '19 at 14:30
  • Now coming to the issue, I tend to agree with scikit-learn devs here to not do this. (1) Because as they mentioned, if this happens at one place, it will need to be done in several other places which take a callable as input (and there are many) to maintain the API. – Vivek Kumar Jan 09 '19 at 14:36
  • (2) Even after all the efforts (if made), the solution you proposed is not complete as I said in my previous comment, the user will need to save the information about the other imports (or maybe other things) used in functions and will have to import/define those things before using the pickled function this way. So in my opinion this tends to be a use-case specific scenario, which `scikit-learn` should not aim for. – Vivek Kumar Jan 09 '19 at 14:38
  • Maybe I am beating around the bush, but first you ask if "I agree that pickling would include _all_ the required dependencies for the pipeline" and then say that "marshalling _all_ of the dependencies falls beyond the scope". Maybe I am criticizing too much? – Vivek Kumar Jan 09 '19 at 14:41
  • There are a lot of edge cases I do not want to cover; for example, if you marshal an object with one version of Python and unmarshal it in another version, there can be unexpected results. There are other PyPI modules (e.g. `dill`) that aim to solve these issues, and my point is not to reinvent the wheel. – Uri Goren Jan 09 '19 at 21:25
  • Regarding `sklearn`, when you `pickle` a `TfidfVectorizer` transformer, you expect it to store the `vocab`, `tf` and `idf` in order to work. I think that `FunctionTransformer`, whose sole purpose is to wrap a function in a transformer, should at least save this function, or raise a warning if that's not possible. P.S. I've edited my Q&A in light of this discussion. – Uri Goren Jan 09 '19 at 21:30
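The limitations raised in the comments (not every function is marshallable, and modules such as `np` must be resolvable at load time) can be checked with a small stdlib-only sketch. `rebuild` is a hypothetical helper that mimics the answer's __setstate__; it is not a library function:

```python
import marshal
import math
from types import FunctionType

def rebuild(fn, namespace):
    """Hypothetical helper: round-trip fn's code through marshal, as __setstate__ does."""
    code = marshal.loads(marshal.dumps(fn.__code__))
    return FunctionType(code, namespace, fn.__name__)

# 1. Default argument values live on the function object, not in __code__.
def scale(x, k=3):
    return x * k

try:
    rebuild(scale, globals())(2)
    defaults_lost = False
except TypeError:
    defaults_lost = True  # missing required positional argument 'k'

# 2. A closure needs its captured cells, which marshal does not carry.
def make_adder(n):
    def add(x):
        return x + n
    return add

try:
    rebuild(make_adder(1), globals())
    closure_ok = True
except (TypeError, ValueError):  # exception type varies across CPython versions
    closure_ok = False

# 3. Global names (e.g. an imported module) are looked up in the new namespace.
def root(x):
    return math.sqrt(x)

try:
    rebuild(root, {"__builtins__": __builtins__})(4.0)
    globals_ok = True
except NameError:
    globals_ok = False  # name 'math' is not defined there

print(defaults_lost, closure_ok, globals_ok)  # True False False
```

This is why the answer passes the defining module's globals() to FunctionType: any module the wrapped function uses must already be imported there when the pipeline is loaded.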