Search code examples
Tags: python, parallel-processing, python-multiprocessing, tar, tarfile

Can't map a function to tarfile members in parallel


I have a tarfile containing bz2-compressed files. I want to apply the function clean_file to each of the bz2 files, and collate the results. In series, this is easy with a loop:

import pandas as pd
import json
import os
import bz2
import itertools
import datetime
import tarfile
from multiprocessing import Pool

def clean_file(member):
    """Parse one bz2-compressed JSON-lines member of the module-level tar ``tr``.

    Every line is parsed (after repairing the malformed ``"name"}`` key) and
    the first resulting object is returned.

    Note: this reads through the single module-level ``tr`` handle; that shared
    handle is what fails when this function runs in forked ``Pool`` workers.

    Args:
        member: a ``tarfile.TarInfo`` (or member name) from ``tr``.

    Returns:
        The first parsed JSON object, or ``None`` for non-.bz2 members and
        for empty files.
    """
    if '.bz2' not in str(member):
        return None

    # Closing the bz2 wrapper does not close the tar member object handed to
    # it, so both get a context manager and are released even if parsing fails.
    with tr.extractfile(member) as f, bz2.open(f, "rt") as bzinput:
        dicts = []
        for line in bzinput:
            line = line.replace('"name"}', '"name":" "}')
            dicts.append(json.loads(line))

    return dicts[0] if dicts else None


# Open tar file and get contents (members).
# `tr` is deliberately left open: clean_file reads members through it.
tr = tarfile.open('data.tar')
members = tr.getmembers()
num_files = len(members)


# Apply the clean_file function in series, reporting 1-based progress
processed_files = []
for i, m in enumerate(members, start=1):
    processed_files.append(clean_file(m))
    print('done ' + str(i) + '/' + str(num_files))
    

However, I need to be able to do this in parallel. The method I'm trying uses Pool like so:

# Apply the clean_file function in parallel.
# Under the "fork" start method every worker inherits the parent's open `tr`
# handle (and its shared file offset), which is what makes this attempt fail.
if __name__ == '__main__':
    with Pool(2) as p:
        # Pool.map already returns a list; no extra list() needed.
        processed_files = p.map(clean_file, members)

But this raises an OSError:

Traceback (most recent call last):
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "parse_data.py", line 19, in clean_file
    for i, line in enumerate(bzinput):
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/bz2.py", line 195, in read1
    return self._buffer.read1(size)
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/_compression.py", line 68, in readinto
    data = self.read(len(byte_view))
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/_compression.py", line 103, in read
    data = self._decompressor.decompress(rawblock, size)
OSError: Invalid data stream
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "parse_data.py", line 53, in <module>
    processed_files = list(tqdm.tqdm(p.imap(clean_file, members), total=num_files))
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/site-packages/tqdm/std.py", line 1167, in __iter__
    for obj in iterable:
  File "/Users/johnfoley/opt/anaconda3/envs/racing_env/lib/python3.6/multiprocessing/pool.py", line 735, in next
    raise value
OSError: Invalid data stream

So I suspect this approach isn't accessing the files inside data.tar correctly. How can I apply the function in parallel?

I expect this error to reproduce with any tar archive containing bz2 files, but here is my data for reproducing it: https://github.com/johnf1004/reproduce_tar_error


Solution

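  • It seems a race condition was occurring. Under the "fork" start method, the worker processes inherit the parent's single open handle to data.tar, so their concurrent reads move one shared file offset and a worker can end up decompressing bytes from the wrong position. Opening the tar file separately in every child process solves the issue: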

    import json
    import bz2
    import tarfile
    import logging
    from multiprocessing import Pool
    
    
    def clean_file(member, tar_path='data.tar'):
        """Parse one bz2-compressed JSON-lines member of a tar archive.

        The archive is reopened inside this function so that every worker
        process gets its own file handle instead of sharing the parent's.

        Every line of the member is parsed (after repairing the malformed
        ``"name"}`` key), so a corrupt stream anywhere in the file is reported.

        Args:
            member: a ``tarfile.TarInfo`` (or member name) from the archive.
            tar_path: path of the tar archive; defaults to ``'data.tar'``.

        Returns:
            The first parsed JSON object, or ``None`` when the member is not a
            ``.bz2`` file, is empty, or could not be processed (the error is
            logged).
        """
        # Match on the member's name, not a substring of the TarInfo repr.
        name = getattr(member, 'name', str(member))
        if not name.endswith('.bz2'):
            return None
        try:
            with tarfile.open(tar_path) as tr:
                with tr.extractfile(member) as bz2_file:
                    with bz2.open(bz2_file, "rt") as bzinput:
                        dicts = []
                        for line in bzinput:
                            line = line.replace('"name"}', '"name":" "}')
                            dicts.append(json.loads(line))
            return dicts[0] if dicts else None
        except Exception:
            # Best-effort per member: log and keep the pool running.
            logging.exception(f"Error while processing {member}")
            return None
    
    
    def process_serial():
        """Apply ``clean_file`` to every member of ``data.tar``, one at a time.

        Prints a 1-based progress line (``done i/N``) after each member.

        Returns:
            List of ``clean_file`` results, in archive order.
        """
        # Only the member list is needed here; clean_file reopens the archive.
        with tarfile.open('data.tar') as tr:
            members = tr.getmembers()
        processed_files = []
        for i, member in enumerate(members, start=1):
            processed_files.append(clean_file(member))
            print(f'done {i}/{len(members)}')
        return processed_files
    
    
    def process_parallel():
        """Apply ``clean_file`` to every member of ``data.tar`` with a process pool.

        The parent only lists the members and closes its handle before the
        pool starts, so no open tar file is inherited by the workers.

        Returns:
            List of ``clean_file`` results, in archive order (also printed).
        """
        with tarfile.open('data.tar') as tr:
            members = tr.getmembers()
        with Pool() as pool:
            processed_files = pool.map(clean_file, members)
        print(processed_files)
        return processed_files
    
    
    def main():
        """Script entry point: run the pool-based pipeline over data.tar."""
        process_parallel()
        return None
    
    
    if __name__ == '__main__':
        # Guarded so worker processes that import this module do not recurse.
        main()
    

    EDIT:

    Note that another way to solve this problem is to just use the spawn start method:

    # Call at most once, under `if __name__ == '__main__':`, before creating the Pool.
    multiprocessing.set_start_method(method='spawn')
    

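    With "spawn", each child process starts a fresh interpreter and re-imports the main module. Any module-level tarfile.open(...) therefore runs again in the child, which gets its own file handle and offset. Under the "fork" start method (the default on Unix for the Python 3.6 used here), the child inherits the parent's open file descriptors, which share a single file offset. As a result, concurrent reads from different workers interfere with each other.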