Tags: python, regex, scikit-learn, logistic-regression, countvectorizer

CountVectorizer preprocessing related to regex


I am doing text processing with CountVectorizer and logistic regression, comparing the F1 score with and without preprocessing. I would like to use a regex for the preprocessing, so I wrote the code below:

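import re
from sklearn import metrics
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

def better_preprocessor(s):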
    lower = s.lower()
    lower = re.sub(r'^\w{8,}$', lambda x:x[:7], lower)
    return lower

def a():
    cv = CountVectorizer()
    train = cv.fit_transform(train_data)
    features = cv.get_feature_names()
    cv_dev = CountVectorizer(vocabulary = features)
    dev = cv_dev.fit_transform(dev_data)
    print(features)

    lgr = LogisticRegression(C=0.5, solver="liblinear", multi_class="auto")
    lgr.fit(train, train_labels)
    lgr_pred = lgr.predict(dev)
    score = metrics.f1_score(dev_labels, lgr_pred, average="weighted")
    print('No preprocessing score:', score)

    cv_im = CountVectorizer(preprocessor=better_preprocessor)
    train_im = cv_im.fit_transform(train_data)
    features_im = cv_im.get_feature_names()
    cv_im_dev = CountVectorizer(preprocessor=better_preprocessor, vocabulary = features_im)
    dev_im = cv_im_dev.fit_transform(dev_data)

    lgr.fit(train_im, train_labels)
    lgr_pred_im = lgr.predict(dev_im)
    score_im = metrics.f1_score(dev_labels, lgr_pred_im, average="weighted")
    print('Preprocessing score', score_im)
    print(len(features)-len(features_im))
    print(features_im)

a()

I tried to truncate every word of 8 or more characters to 7 characters, but when I checked the vocabulary with get_feature_names, nothing had changed. I don't know what I need to fix.


Solution

  • You do not need a regex to truncate a single word. Use

    def better_preprocessor(s):
        # Slicing is safe on strings shorter than 7 characters, so no length check is needed
        return s.lower()[:7]
    

    The re.sub(r'^\w{8,}$', lambda x:x[:7], lower) code takes the lower string and tries to match ^\w{8,}$:

    • ^ - start of string
    • \w{8,} - eight or more word chars
    • $ - end of string.

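    The lambda x:x[:7] then receives the match as x, which is a match object rather than a string, so slicing it with x[:7] does not give the first seven characters and raises an error whenever the pattern actually matches. You probably meant x.group()[:7], but that is still overkill here. Note also that CountVectorizer passes the whole document to the preprocessor, not one word at a time, so the anchored pattern only matches when the entire text is a single word of eight or more characters. That is why the vocabulary did not change.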
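Both points can be checked with a minimal sketch that uses only the standard re module. The sample strings are made up for illustration, and the sketch assumes the preprocessor is handed a whole document, which is how CountVectorizer calls it.

```python
import re

pattern = r'^\w{8,}$'

# A typical document contains spaces, so the anchored pattern cannot match
# and re.sub returns the text unchanged (the lambda is never called).
doc = "preprocessing documentation matters"
unchanged = re.sub(pattern, lambda m: m.group()[:7], doc)
print(unchanged == doc)  # True

# Only a string that is one long word matches from start to end.
single = re.sub(pattern, lambda m: m.group()[:7], "documentation")
print(single)  # documen

# Slicing the match object itself fails; the exact exception type
# depends on the Python version, so both likely ones are caught here.
try:
    re.sub(pattern, lambda m: m[:7], "documentation")
    raised = False
except (TypeError, IndexError):
    raised = True
print(raised)  # True
```

The first case is the one the question runs into: no match, no error, and no change in the output.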

    If you want to truncate every word inside a longer string, you need to define what counts as a word and use

    def better_preprocessor(s):
        return re.sub(r'\b(\w{7})\w+', r'\1', s.lower())
    

    Pattern details:

    • \b - word boundary
    • (\w{7}) - Group 1 (referred to with \1 from the replacement pattern): seven word chars
    • \w+ - 1+ word chars
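To see the word-level truncation in action, here is a small sketch using only re; the sample sentence is invented for illustration. Keep in mind that \w also matches digits and underscores, so those count toward the seven characters.

```python
import re

def better_preprocessor(s):
    # Lowercase, then cut every run of 8+ word characters down to its first 7.
    return re.sub(r'\b(\w{7})\w+', r'\1', s.lower())

result = better_preprocessor("Preprocessing improves Classification results")
print(result)  # preproc improve classif results
```

Words of seven characters or fewer, such as "results", are left alone because the pattern needs at least one character after the captured seven. The function can be passed as CountVectorizer(preprocessor=better_preprocessor), as in the question.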