Tags: python, scikit-learn, classification, text-classification

How to correctly override and call super-method in Python


First, the problem at hand: I am writing a wrapper for a scikit-learn class and am having trouble with the right syntax. What I am trying to achieve is an override of the fit_transform function that alters the input only slightly and then calls its super-method with the new parameters:

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):
    def __init__(self):
        TfidfVectorizer.__init__(self)  # is this even necessary?

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return TfidfVectorizer.fit_transform(self, x, y, fit_params)  
                            # this is the critical part, my IDE tells me for
                            # fit_params: 'unexpected arguments'

The program crashes all over the place, starting with a multiprocessing exception that doesn't tell me anything useful. How do I do this correctly?

Additional info: The reason why I need to wrap it this way is that I use sklearn.pipeline.FeatureUnion to collect my feature extractors before putting them into a sklearn.pipeline.Pipeline. A consequence of doing it this way is that I can only feed a single data set to all feature extractors -- but different extractors need different data. My solution is to feed the data in an easily separable format and filter out different parts in different extractors. If there is a better solution to this problem, I'd also be happy to hear it.
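
For context, here is roughly what the setup looks like -- a sketch only; the second wrapper, the classifier and the toy data are placeholders, not my actual code:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.svm import LinearSVC

class SecondFieldVectorizer(TfidfVectorizer):
    # placeholder sibling wrapper that picks the second tab-separated field
    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[1] for content in x]
        return TfidfVectorizer.fit_transform(self, x, y, **fit_params)

features = FeatureUnion([
    ('tfidf_first', TidfVectorizerWrapper()),   # uses field 0
    ('tfidf_second', SecondFieldVectorizer()),  # uses field 1
])

pipeline = Pipeline([
    ('features', features),
    ('clf', LinearSVC()),  # placeholder classifier
])

# every sample carries the data for all extractors, tab-separated
data = ['first text\tsecond text', 'more text\tother text']
labels = [0, 1]
pipeline.fit(data, labels)

(For predicting, a matching transform override would be needed as well; fitting is enough to show the wiring.)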

Edit 1: Adding ** to unpack the dict does not seem to change anything: Screenshot

Edit 2: I just solved the remaining problem -- I needed to remove the constructor override. Apparently, by calling the parent constructor without any arguments, hoping to have all instance variables initialized correctly, I did the exact opposite: my wrapper had no idea what parameters it could expect. Once I removed the superfluous override, everything worked out perfectly.
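
For anyone who does need a constructor on such a wrapper: as far as I understand, the parameters have to be re-declared explicitly and forwarded, so that scikit-learn can still introspect them. A sketch (the parameter names are just examples):

from sklearn.feature_extraction.text import TfidfVectorizer

class TidfVectorizerWrapper(TfidfVectorizer):
    # only needed when exposing specific parameters; max_features and
    # ngram_range are just example TfidfVectorizer parameters
    def __init__(self, max_features=None, ngram_range=(1, 1)):
        TfidfVectorizer.__init__(self, max_features=max_features,
                                 ngram_range=ngram_range)

    def fit_transform(self, x, y=None, **fit_params):
        x = [content.split('\t')[0] for content in x]  # filtering the input
        return TfidfVectorizer.fit_transform(self, x, y, **fit_params)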


Solution

  • You forgot to unpack fit_params: it is passed as a dict, and to pass it through as keyword arguments it needs the unpacking operator **.

    from sklearn.feature_extraction.text import TfidfVectorizer
    
    class TidfVectorizerWrapper(TfidfVectorizer):
    
        def fit_transform(self, x, y=None, **fit_params):
            x = [content.split('\t')[0] for content in x]  # filtering the input
            return TfidfVectorizer.fit_transform(self, x, y, **fit_params)  
    

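    A quick sanity check of the fixed wrapper on toy data (default parameters, no extra fit_params needed):

    docs = ['relevant text\tignored metadata',
            'another document\tmore metadata']

    vec = TidfVectorizerWrapper()
    X = vec.fit_transform(docs)  # y defaults to None, fit_params stays empty
    print(X.shape)               # (2, 4): two documents, four distinct terms from field 0
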
    One other thing: instead of calling TfidfVectorizer's fit_transform directly, you can call the overridden version through super:

    from sklearn.feature_extraction.text import TfidfVectorizer
    
    class TidfVectorizerWrapper(TfidfVectorizer):
    
        def fit_transform(self, x, y=None, **fit_params):
            x = [content.split('\t')[0] for content in x]  # filtering the input
            return super(TidfVectorizerWrapper, self).fit_transform(x, y, **fit_params)  
    

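    On Python 3, the same return statement can also use the zero-argument form of super():

        return super().fit_transform(x, y, **fit_params)
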
    To understand the unpacking, check the following example:

    def foo1(**kargs):
        print(kargs)
    
    def foo2(**kargs):
        foo1(**kargs)
        print('foo2')
    
    def foo3(**kargs):
        foo1(kargs)
        print('foo3')
    
    foo1(a=1, b=2)
    

    it prints the dictionary {'a': 1, 'b': 2}

    foo2(a=1, b=2)
    

    prints both the dictionary and foo2, but

    foo3(a=1, b=2)
    

    raises an error, because we passed our dictionary as a single positional argument to foo1, which does not accept positional arguments. We could, however, do

    def foo4(**kargs):
        foo1(x=kargs)
        print('foo4')
    

    which works fine, but prints a new dictionary {'x': {'a': 1, 'b': 2}}