
Proper conditional chaining in Polars


I have this piece of code that conditionally applies additional manipulations to a column ('NaicsDescription'). It seems to work, but I am not sure it is the best or cleanest way to do this in Polars, and I would appreciate some guidance.

The idea is to:

  1. apply the lowercase operation on the column,
  2. conditionally apply some additional operations,
  3. and finally finish up with some additional common operations.

Is this correct?

import polars as pl
from unicodedata import normalize 
from lingua import Language, LanguageDetectorBuilder
from nltk.corpus import stopwords  # run nltk.download('stopwords') once beforehand

languages = [Language.ENGLISH, Language.FRENCH]
detector = LanguageDetectorBuilder.from_languages(*languages).build()

test = pl.DataFrame(
    {
        "id": [1, 2, 3, 4],
        "NaicsDescription": ['Full service restuarants', 'The Manufacturing of toys and trains', 'POWER GENERATING STATIONS', 'the short term rental of cottages']
    }
)


def lang_identifier(text: str):
    '''
    Language identification of text. This is a thin wrapper around lingua's detect_language_of.

    Parameters
    ----------
    text : str
        Input string to identify language for.

    Returns
    -------
    language : str or None
        Returns 'EN' or 'FR', or None if the input is not a string or no language is detected.

    '''
    language = None
    if isinstance(text, str):
        language = detector.detect_language_of(text)
        if language:
            language = language.iso_code_639_1.name
    return language


en_stopwords = '|'.join(set(w.lower() for w in stopwords.words('english')))
fr_stopwords = '|'.join(set(w.lower() for w in stopwords.words('french')))
bilingual_stopwords = '|'.join(set(w.lower() for w in stopwords.words('english') + stopwords.words('french')))

test = (
    test.with_columns(
        pl.col('NaicsDescription').str.to_lowercase().alias('NaicsDescription_')
    )
    .with_columns(
        pl.when(pl.col('NaicsDescription').map_elements(lang_identifier, return_dtype=pl.String) == 'FR')
        .then(pl.col('NaicsDescription_').str.replace_all(r'\b(?:' + fr_stopwords + r')\b', ' '))
        .when(pl.col('NaicsDescription').map_elements(lang_identifier, return_dtype=pl.String) == 'EN')
        .then(pl.col('NaicsDescription_').str.replace_all(r'\b(?:' + en_stopwords + r')\b', ' '))
        .otherwise(pl.col('NaicsDescription_').str.replace_all(r'\b(?:' + bilingual_stopwords + r')\b', ' '))
    )
    .with_columns(
        pl.col('NaicsDescription_')
        .map_elements(
            lambda x: normalize('NFKD', x).encode('ascii', errors='ignore').decode('utf-8'),
            return_dtype=pl.String,
        )
        .str.replace_all(r'(?:[^\s\w]|_\d)+', ' ')
        .str.replace_all(r'\b(?:\d+|\w{1,2})\b', ' ')
        .str.replace_all(r'\s\s+', ' ')
        .str.strip_chars()
        .replace('', None)
    )
)

Edit: full working example added.


Solution

  • It is hard to say without seeing any data whether there is a better way to approach what you are doing or whether any steps can be optimised. In future, a runnable example does make it a lot easier to help.

    It does seem like each step depends on the last. As a general rule, avoid map_elements unless your logic cannot be expressed with Polars expressions, because it calls a Python function once per element instead of running vectorised native code. Your lang_identifier function probably cannot be expressed as a Polars expression, although it's impossible to be certain. Your Unicode normalization looks like the best approach and is in line with this SO answer.

    More than anything, I think this code can be refactored so the intent is clearer, with repeated operations such as identifying the language and removing stopwords factored out into a function or variable.

    Here is my attempt at that. It is purely a refactor and shouldn't change the behaviour. It probably isn't perfect, as I haven't been able to run it.

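    # Reuses the imports, lang_identifier and *_stopwords patterns from the question.
    def remove_stopwords(text: pl.Expr, pattern: str) -> pl.Expr:
        """Replace whole-word matches of a '|'-joined stopword pattern with a space."""
        return text.str.replace_all(r"\b(?:" + pattern + r")\b", " ")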
    
    
    naics_description = pl.col("NaicsDescription")
    naics_description_lower = naics_description.str.to_lowercase()
    
    identified_language = naics_description.map_elements(
        lang_identifier,
        return_dtype=pl.String,
    )
    
    naics_description_stopwords_removed = (
        pl.when(identified_language == "FR")
        .then(remove_stopwords(naics_description_lower, fr_stopwords))
        .when(identified_language == "EN")
        .then(remove_stopwords(naics_description_lower, en_stopwords))
        .otherwise(remove_stopwords(naics_description_lower, bilingual_stopwords))
    )
    
    test = test.with_columns(
        naics_description_stopwords_removed.map_elements(
            lambda x: normalize("NFKD", x).encode("ascii", errors="ignore").decode("utf-8"),
            return_dtype=pl.String,
        )
        .str.replace_all(r"(?:[^\s\w]|_\d)+", " ")
        .str.replace_all(r"\b(?:\d+|\w{1,2})\b", " ")
        .str.replace_all(r"\s\s+", " ")
        .str.strip_chars()
        .replace("", None)
        # Write to a new column, as the original code did, instead of overwriting NaicsDescription.
        .alias("NaicsDescription_")
    )