
First, the problem at hand. I am writing a wrapper for a scikit-learn class and am having trouble getting the syntax right. What I am trying to achieve is an override of the fit_transform method that alters the input only slightly and then calls the parent method with the new parameters:

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):
    def __init__(self):
        TfidfVectorizer.__init__(self)  # is this even necessary?

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return TfidfVectorizer.fit_transform(self, x, y, fit_params)  
                            # this is the critical part, my IDE tells me for
                            # fit_params: 'unexpected arguments'

The program crashes all over the place, starting with a multiprocessing exception that does not really tell me anything useful. How do I do this correctly?

Additional info: The reason I need to wrap it this way is that I use sklearn.pipeline.FeatureUnion to collect my feature extractors before putting them into a sklearn.pipeline.Pipeline. A consequence of doing it this way is that I can only feed a single data set to all feature extractors -- but different extractors need different data. My solution was to feed the data in an easily separable format and to filter out different parts in different extractors. If there is a better solution to this problem, I'd also be happy to hear it.

Edit 1: Adding ** to unpack the dict does not seem to change anything.

Edit 2: I just solved the remaining problem -- I needed to remove the constructor override. Apparently, by calling the parent constructor in the hope of having all instance variables initialized correctly, I achieved the exact opposite: my wrapper had no idea what parameters it could expect. Once I removed the superfluous constructor, everything worked perfectly.
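
A minimal sketch of why the explicit constructor could have this effect, assuming parameter names are discovered by inspecting the signature of `__init__` (which is how scikit-learn estimators report their parameters). The classes `Base`, `WithOverride`, `WithoutOverride` and the `param_names` helper are hypothetical stand-ins written for illustration, not scikit-learn code:

```python
import inspect


class Base:
    """Hypothetical stand-in for an estimator that reads its parameters from __init__."""

    def __init__(self, lowercase=True, max_features=None):
        self.lowercase = lowercase
        self.max_features = max_features

    @classmethod
    def param_names(cls):
        # Collect the constructor's argument names, skipping `self`.
        signature = inspect.signature(cls.__init__)
        return sorted(name for name in signature.parameters if name != "self")


class WithOverride(Base):
    # The argument-less constructor hides the parent's parameters.
    def __init__(self):
        Base.__init__(self)


class WithoutOverride(Base):
    # No constructor: the parent's signature stays visible.
    pass


print(WithOverride.param_names())
print(WithoutOverride.param_names())
```

In this sketch the subclass with the argument-less constructor reports no parameters at all, while the one without a constructor still reports those of its parent.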

Arne
  • did you try to replace `return TfidfVectorizer.fit_transform(self, x, y, fit_params)` with `return TfidfVectorizer.fit_transform(self, x, y, **fit_params)`? – user3012759 Apr 22 '15 at 15:41
  • Also, depending on the position of `y` in the method's signature, it might be required to "name" `y`: `return TfidfVectorizer.fit_transform(self, x, y=y, **fit_params)` – Klaus D. Apr 22 '15 at 15:44
  • @KlausD. agreed, but fit_params is almost certainly not the right thing to pass in, as it is a dictionary – user3012759 Apr 22 '15 at 15:45
  • I tried both right now, adding `**fit_params` instead changed nothing, exact same IDE warning and runtime exception when trying it anyway. Adding y=y gives a syntax error, because fit_params is entered after a keyword. Changing it around gives another syntax error. – Arne Apr 22 '15 at 16:10

1 Answer


You forgot to unpack fit_params. Inside the method it is a dict, and to pass it through as keyword arguments you need the unpacking operator **.

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return TfidfVectorizer.fit_transform(self, x, y, **fit_params)  

One other thing: instead of calling TfidfVectorizer's fit_transform directly, you can call the parent's version through super:

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return super(TidfVectorizerWrapper, self).fit_transform(x, y, **fit_params)  
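
The same override-and-forward pattern can be tried without scikit-learn installed. The sketch below is only an illustration: `BaseVectorizer` and `FirstFieldVectorizer` are hypothetical stand-ins for the parent class and the wrapper, and the parent simply echoes what it receives so the forwarding is visible. It uses the argument-less Python 3 form of `super()`:

```python
class BaseVectorizer:
    """Hypothetical stand-in for the parent class; it echoes its inputs."""

    def fit_transform(self, x, y=None, **fit_params):
        return {"docs": list(x), "y": y, "fit_params": fit_params}


class FirstFieldVectorizer(BaseVectorizer):
    def fit_transform(self, x, y=None, **fit_params):
        # Keep only the first tab-separated field of every document.
        x = [content.split("\t")[0] for content in x]
        # Forward the keyword arguments unpacked, not as a single dict.
        return super().fit_transform(x, y, **fit_params)


result = FirstFieldVectorizer().fit_transform(
    ["some text\tmeta", "more text\tother"], verbose=True
)
print(result)
```

Here `verbose` is an arbitrary keyword chosen for the demonstration; the stand-in parent accepts any keyword, which a real parent method may not.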

To understand the difference, consider the following example:

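def foo1(**kargs):
    print(kargs)

def foo2(**kargs):
    foo1(**kargs)  # unpacked: forwarded as keyword arguments
    print('foo2')

def foo3(**kargs):
    foo1(kargs)  # not unpacked: the dict is passed as one positional argument
    print('foo3')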

foo1(a=1, b=2)

prints the dictionary {'a': 1, 'b': 2},

foo2(a=1, b=2)

prints both the dictionary and foo2, but

foo3(a=1, b=2)

raises a TypeError, because foo1 receives the dictionary as a single positional argument, which it does not accept. We could, however, do

def foo4(**kargs):
    foo1(x=kargs)  # the whole dict becomes the value of the keyword x
    print('foo4')

which works fine, but prints a new dictionary {'x': {'a': 1, 'b': 2}}
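
For readers on Python 3, the three cases can be collected into one runnable sketch. The function names below are hypothetical renamings of foo1 to foo4, and the functions return their arguments instead of printing them so the results can be compared directly:

```python
def show(**kwargs):
    # Accepts keyword arguments only and returns them as a dict.
    return kwargs


def forward_unpacked(**kwargs):
    # Correct forwarding: the dict is unpacked back into keyword arguments.
    return show(**kwargs)


def forward_positional(**kwargs):
    # Wrong: the dict is passed as one positional argument.
    return show(kwargs)


def forward_nested(**kwargs):
    # Valid, but the dict ends up nested under the keyword x.
    return show(x=kwargs)


print(forward_unpacked(a=1, b=2))
print(forward_nested(a=1, b=2))

try:
    forward_positional(a=1, b=2)
    error = None
except TypeError as exc:
    error = exc
print(type(error).__name__)
```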

lejlot