Tags: python, optimization, country

Problems optimizing a function that finds countries mentioned in text so that it runs faster


I created a function that combs strings for mentions of countries. It is based on a .txt file that contains the many different ways people mention a country in text. The file looks like this:

"afghanistan": ["afghan", "afghans"], "albania": ["albanian", "albanians"], "algeria": ["algerian", "algerians"], "angola": ["angolan", "angolans"], ... and so on, for every country on earth.

I then created a function that combs a string and searches for these mentions, but it runs slowly on large datasets. I really want to make the function run faster, but I don't know how.

The function looks like this:



import json
import string
from re import sub
from typing import List, Union


def find_countries(text: str, exclude: Union[str, List[str]] = [], extra: Union[str, List[str]] = []) -> Union[List[str], str]:
    """
    Parameters
    ----------
    `text` : `str`
        The text to extract countries from.
    `exclude` : `list or str`
        Optional. Countries to exclude from search.
    `extra` : `list or str`
        Optional. Additional terms to search for (usually orgs).
    """

    # Load country names from file
    with open('country_names.txt') as file:
        country_names = json.load(file)

    # Convert 'exclude' and 'extra' to lists
    exclude = [exclude] if isinstance(exclude, str) else exclude
    extra = [extra] if isinstance(extra, str) else extra

    # Include 'extra' countries or orgs
    for i in extra:
        country_names[i.lower()] = []

    # Remove 'exclude' countries using set operations
    exclude_set = set(exclude)
    countries = {country for country in country_names.keys() if country.lower() not in exclude_set}

    # Clean and preprocess the input text
    my_punct = string.punctuation + '”“'
    replace_punct_string = "['’-]"
    text = sub(replace_punct_string, " ", text)
    text = text.translate(str.maketrans('', '', my_punct)).lower()

    # Search for country mentions using a set comprehension
    countries_mentioned = {country for country in countries
                           if any(f' {name} ' in f' {text} ' for name in {country} | set(country_names[country]))}

    return list(countries_mentioned)

The function receives a string and combs it for mentions of countries, which it then returns as a list. I usually apply it to a Pandas Series.

I think the code as it is now is "fine" - it isn't long and it does the job. I hope you can help me make it run faster, so that when I apply it to tens of thousands of texts it won't take years to finish. Also, any tips on writing better code would help a lot!


Solution

  • You do a lot of on-the-fly converting that seems completely unnecessary. If you only use set functionality, you should pass sets in the first place. As far as I can tell you don't need the ordering of a list, so pass sets into the arguments rather than lists; that way you can drop all the conversion code inside the function.

    Additionally, if the file is not too large and you call the function many times, you can save a lot of time by loading the data only once, globally, and keeping it in memory instead of reloading it inside the function on every call. You could, for example, create a data structure that loads the data automatically and caches it to prevent reloads. The @property decorator is well suited for such use cases.
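    If a full class feels like too much, a cached loader function achieves the same effect; here is a minimal sketch using functools.lru_cache (the helper name load_country_names is just illustrative, the file name is taken from the question):

    import json
    from functools import lru_cache

    @lru_cache(maxsize=1)
    def load_country_names() -> dict:
        # The file is read only on the first call; later calls return the cached dict
        with open("country_names.txt") as file:
            return json.load(file)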

    I would also create a dictionary that maps each variant to its canonical country name. Something like

    {
        "afghan": "afghanistan",
        "afghans": "afghanistan",
        # ...
    }
    

    With this you can move one loop out of your function entirely.
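    Building that mapping is a one-liner dict comprehension; a minimal sketch, assuming country_names is the dict parsed from your file:

    country_by_variant = {
        variant: country
        for country, variants in country_names.items()
        for variant in [country, *variants]
    }
    # country_by_variant["afghans"] -> "afghanistan"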

    One warning though: you should almost never use an empty list as a default value in the argument list. Here is why - found at this SO Post
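    The usual idiom, and what the rewrite below uses, is to default to None and create the empty container inside the function. A small illustration of the pitfall and the fix:

    def broken(item, bag=[]):
        # The default list is created once at definition time and shared by all calls
        bag.append(item)
        return bag

    broken(1)  # [1]
    broken(2)  # [1, 2] - the same list keeps growing

    def fixed(item, bag=None):
        # Defaulting to None and building a fresh list per call avoids the shared state
        if bag is None:
            bag = []
        bag.append(item)
        return bag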

    Edit

    Actually, this "flipped" dictionary is not helpful. As you mentioned, there is also a problem with matching subwords, e.g. Oman in woman. You can prevent this, and possibly even speed things up a bit, by using regex word boundaries (I haven't done a performance test).

    import json
    from typing import Optional

    import regex


    class CountryProvider:
        def __init__(self):
            self._countries: Optional[set[str]] = None
            self._patterns: Optional[dict[str, regex.Pattern]] = None

        def _load_countries(self):
            # Load the file once and build one pattern per country that matches
            # the country name or any of its variants as a whole word
            with open("country_names.txt") as file:
                countries = json.load(file)
            self._patterns = {
                country: regex.compile(
                    r"\b(" + "|".join([country, *variants]) + r")\b", regex.IGNORECASE
                )
                for country, variants in countries.items()
            }
            self._countries = set(countries.keys())

        @property
        def countries(self) -> set[str]:
            if self._countries is None:
                self._load_countries()
            return self._countries

        @property
        def patterns(self) -> dict[str, regex.Pattern]:
            if self._patterns is None:
                self._load_countries()
            return self._patterns


    COUNTRY_PROVIDER = CountryProvider()


    def find_countries(
        text: str,
        exclude: Optional[list[str]] = None,
        extra: Optional[dict[str, list[str]]] = None,
    ) -> list[str]:
        # Use empty containers if 'exclude' / 'extra' were not given
        exclude = set(exclude) if exclude else set()
        extra = extra or {}
        # No text preprocessing needed: the patterns are case-insensitive
        # and only match whole words
        countries = COUNTRY_PROVIDER.countries
        patterns = COUNTRY_PROVIDER.patterns
        extra_patterns = {
            country: regex.compile(
                r"\b(" + "|".join([country, *variants]) + r")\b", regex.IGNORECASE
            )
            for country, variants in extra.items()
            if country not in exclude
        }
        mentioned_countries: list[str] = []
        for country in countries:
            if country in exclude:
                continue
            if patterns[country].search(text) is not None:
                mentioned_countries.append(country)
        for country, pattern in extra_patterns.items():
            if pattern.search(text) is not None:
                mentioned_countries.append(country)
        return mentioned_countries

    Note that the patterns dictionary contains one regex pattern per country, which matches the country name and all of its variants as whole words.
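    Since you apply the function to a Pandas Series, usage would look roughly like this (the series texts and its contents are just for illustration; country_names.txt still has to be on disk):

    import pandas as pd

    texts = pd.Series([
        "The Afghan delegation met Albanian officials.",
        "No countries are mentioned here.",
    ])
    mentions = texts.apply(find_countries)
    # Each element is the list of countries found in that text,
    # e.g. ['afghanistan', 'albania'] for the first row and [] for the second.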