
How can I load data that can't be pickled in each Spark executor?


I'm using the NoAho library, which is written in Cython. Its internal trie cannot be pickled: if I load it on the master node, I never get matches for operations that execute on workers.

Since I would like to use the same trie in each Spark executor, I found a way to load the trie lazily, inspired by this spaCy on Spark issue.

global trie

def get_match(text):
    # 1. Load trie if needed
    global trie
    try:
        trie
    except NameError:
        from noaho import NoAho

        trie = NoAho()
        trie.add(key_text='ms windows', payload='Windows 2000')
        trie.add(key_text='ms windows 2000', payload='Windows 2000')
        trie.add(key_text='windows 2k', payload='Windows 2000')
        ...

    # 2. Find an actual match to get the payload back
    return trie.findall_long(text)

While this works, all the .add() calls are performed for every Spark job, which takes around one minute. Since I'm not sure "Spark job" is the correct term, I'll be more explicit: I use Spark from a Jupyter notebook, and every time I run a cell that needs the get_match() function, the trie is rebuilt from scratch; that one-minute load dominates the run time.
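For reference, here is a minimal sketch of the kind of cell I re-run (the sample strings are made up for illustration):

rdd = sc.parallelize(["I run ms windows 2000 at work",
                      "Debian GNU/Linux has no match"])
# Every action re-runs get_match() on the executors; any executor whose
# Python process has no `trie` yet hits the NameError branch and spends
# about a minute rebuilding it.
rdd.map(get_match).collect()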

Is there anything I can do to ensure the trie gets cached? Or is there a better solution to my problem?


Solution

  • One thing you can try is to use a singleton module to load and initialize the trie. Basically all you need is a separate module with something like this:

    • trie_loader.py

      from noaho import NoAho
      
      def load():
          # Build the trie once; NoAho stores the payload with each key
          # and returns it on a match.
          trie = NoAho()
          trie.add('ms windows', 'Windows 2000')
          trie.add('ms windows 2000', 'Windows 2000')
          trie.add('windows 2k', 'Windows 2000')
          return trie
      
      # Module-level singleton: load() runs once per Python process,
      # at import time.
      trie = load()
      

    and distribute this using standard Spark tools:

    sc.addPyFile("trie_loader.py")
    import trie_loader
    
    rdd = sc.parallelize(["ms windows", "Debian GNU/Linux"])
    rdd.map(lambda x: (x, trie_loader.trie.find_long(x))).collect()
    ## [('ms windows', (0, 10, 'Windows 2000')),
    ##  ('Debian GNU/Linux', (None, None, None))]
    

    This should load the required data every time a Python executor process is started, instead of loading it when the data is first accessed. I am not sure if it will help here, but it is worth a try.
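
    As a side note, this pattern depends on Spark reusing Python worker processes between tasks. A minimal sketch, assuming a standard PySpark setup (the reuse flag defaults to true, so this only makes it explicit):

    from pyspark import SparkConf, SparkContext

    # spark.python.worker.reuse defaults to true; if it were disabled,
    # each task would get a fresh Python worker, trie_loader would be
    # re-imported, and the trie rebuilt every time.
    conf = SparkConf().set("spark.python.worker.reuse", "true")
    sc = SparkContext(conf=conf)
    sc.addPyFile("trie_loader.py")  # ship the singleton module to executors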