I have this piece of code that conditionally applies additional manipulations to a column ('NaicsDescription'). I am not sure if this is the best / cleanest way to do it in Polars, even though it seems to work, and I would appreciate some guidance.
The idea is to...
Is this correct?
import polars as pl
from unicodedata import normalize
from lingua import Language, LanguageDetectorBuilder
from nltk.corpus import stopwords
languages = [Language.ENGLISH, Language.FRENCH]
detector = LanguageDetectorBuilder.from_languages(*languages).build()
test = pl.DataFrame(
    {
        "id": [1, 2, 3, 4],
        "NaicsDescription": [
            'Full service restuarants',
            'The Manufacturing of toys and trains',
            'POWER GENERATING STATIONS',
            'the short term rental of cottages',
        ],
    }
)
def lang_identifier(text: str) -> str | None:
    '''
    Language identification of text. This is just a simple wrapper around the
    lingua detector.

    Parameters
    ----------
    text : str
        Input string to identify the language for.

    Returns
    -------
    language : str
        Either 'EN' or 'FR', or None if the input text is not a string.
    '''
    language = None
    if isinstance(text, str):
        language = detector.detect_language_of(text)
        if language:
            language = language.iso_code_639_1.name
    return language
en_stopwords = '|'.join(set(w.lower() for w in stopwords.words('english')))
fr_stopwords = '|'.join(set(w.lower() for w in stopwords.words('french')))
bilingual_stopwords = '|'.join(set(w.lower() for w in stopwords.words('english') + stopwords.words('french')))
test = test.with_columns(
    pl.col('NaicsDescription').str.to_lowercase().alias('NaicsDescription_')
).with_columns(
    pl.when(pl.col('NaicsDescription').map_elements(lang_identifier, return_dtype=pl.String) == 'FR')
    .then(
        pl.col('NaicsDescription_').str.replace_all(r'\b(?:' + fr_stopwords + r')\b', ' ')
    )
    .when(pl.col('NaicsDescription').map_elements(lang_identifier, return_dtype=pl.String) == 'EN')
    .then(
        pl.col('NaicsDescription_').str.replace_all(r'\b(?:' + en_stopwords + r')\b', ' ')
    )
    .otherwise(
        pl.col('NaicsDescription_').str.replace_all(r'\b(?:' + bilingual_stopwords + r')\b', ' ')
    )
).with_columns(
    pl.col('NaicsDescription_')
    .map_elements(
        lambda x: normalize('NFKD', x).encode('ascii', errors='ignore').decode('utf-8'),
        return_dtype=pl.String,
    )
    .str.replace_all(r'(?:[^\s\w]|_\d)+', ' ')
    .str.replace_all(r'\b(?:\d+|\w{1,2})\b', ' ')
    .str.replace_all(r'\s\s+', ' ')
    .str.strip_chars()
    .replace('', None)
)
edit: full working example added...
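As a plain-Python illustration of what the stopword step does (using a tiny hard-coded stopword list instead of NLTK's, just for demonstration), the pattern is a regex alternation wrapped in word boundaries:

```python
import re

# Tiny stand-in for the NLTK stopword lists used above.
stopwords = ["the", "of", "and", "a"]
pattern = r"\b(?:" + "|".join(stopwords) + r")\b"

text = "the manufacturing of toys and trains"
cleaned = re.sub(pattern, " ", text)
cleaned = re.sub(r"\s\s+", " ", cleaned).strip()
print(cleaned)  # manufacturing toys trains
```

The `\b` boundaries stop short stopwords such as "a" from matching inside longer words like "manufacturing".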
It is hard to say without seeing any data whether there is a better way to approach what you are doing or whether any steps can be optimised. In future, a runnable example does make it a lot easier to help.
It does seem like each step depends on the last. Generally, avoid map_elements unless your logic cannot be expressed in Polars expressions. In this case, it seems unlikely that your lang_identifier function could be expressed as a Polars expression, though it's impossible to be certain. Your unicode normalization seems to be the best approach and is in line with this SO answer.
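For reference, that normalization step strips accents by decomposing characters (NFKD) and then dropping the non-ASCII combining marks in the encode/decode round-trip:

```python
from unicodedata import normalize

def strip_accents(text: str) -> str:
    # NFKD splits 'é' into 'e' + a combining accent mark;
    # encoding to ASCII with errors='ignore' then drops the mark.
    return normalize("NFKD", text).encode("ascii", errors="ignore").decode("utf-8")

print(strip_accents("génie électrique"))  # genie electrique
```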
I think, more than anything, this code can be refactored so the intent is clearer and repeated operations, such as identifying the language and removing stopwords, are factored out into a function or a variable.
Here is my attempt at that. It shouldn't change anything; it's just a refactor. It probably isn't perfect, as I haven't been able to run it.
def remove_stopwords(text: pl.Expr, stopwords: str) -> pl.Expr:
    """Removes supplied stopwords from a string."""
    return text.str.replace_all(r"\b(?:" + stopwords + r")\b", " ")

naics_description = pl.col("NaicsDescription")
naics_description_lower = naics_description.str.to_lowercase()

identified_language = naics_description.map_elements(
    lang_identifier,
    return_dtype=pl.String,
)

naics_description_stopwords_removed = (
    pl.when(identified_language == "FR")
    .then(remove_stopwords(naics_description_lower, fr_stopwords))
    .when(identified_language == "EN")
    .then(remove_stopwords(naics_description_lower, en_stopwords))
    .otherwise(remove_stopwords(naics_description_lower, bilingual_stopwords))
)

test = test.with_columns(
    naics_description_stopwords_removed.map_elements(
        lambda x: normalize("NFKD", x).encode("ascii", errors="ignore").decode("utf-8"),
        return_dtype=pl.String,
    )
    .str.replace_all(r"(?:[^\s\w]|_\d)+", " ")
    .str.replace_all(r"\b(?:\d+|\w{1,2})\b", " ")
    .str.replace_all(r"\s\s+", " ")
    .str.strip_chars()
    .replace("", None)
    # alias so the original column is kept, as in your version
    .alias("NaicsDescription_")
)