
Returning a generator in a with statement


I wanted to create a wrapper function around pandas.read_csv to change the default separator and format the file in a specific way. This is the code I had:

def custom_read(path, sep="|", **kwargs):
    if not kwargs.get("chunksize", False):
        df_ = pd.read_csv(path, sep=sep, **kwargs)
        return format_df(df_, path)
    else:
        with pd.read_csv(path, sep=sep, **kwargs) as reader:
            return (format_df(chunk, path) for chunk in reader)

It turns out that this segfaults when used like so:

L = [chunk.iloc[:10, :] for chunk in custom_read(my_file)]

From what I understood from the backtrace, the generator is created, then the file is closed, and the segfault happens when the generator tries to read from the now-closed file.
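
The same sequence of events can be reproduced without pandas. The sketch below is only an illustration: `broken_lines` is a hypothetical helper and `io.StringIO` stands in for the pandas reader. In pure Python the failure surfaces as a `ValueError` rather than a crash.

```python
import io

def broken_lines(text):
    # Same shape as the failing branch: the generator expression is built
    # inside the with block, but no line is read until the caller iterates.
    with io.StringIO(text) as buffer:
        return (line.strip() for line in buffer)

lines = broken_lines("a\nb\n")  # the buffer is already closed at this point
try:
    result = list(lines)
except ValueError as exc:
    # Reading from the closed buffer fails on the first next() call.
    result = f"ValueError: {exc}"
print(result)
```

The return statement leaves the with block immediately, so the resource is closed before the caller consumes a single item.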

I could avoid the segfault with a minor refactoring:

def custom_read(path, sep="|", **kwargs):
    if not kwargs.get("chunksize", False):
        df_ = pd.read_csv(path, sep=sep, **kwargs)
        return format_df(df_, path)
    else:
        reader = pd.read_csv(path, sep=sep, **kwargs)
        return (format_df(chunk, path) for chunk in reader)

I couldn't find anything on this particular use case of generators in with clauses. Is it something to avoid? Is this not supposed to work, or is it a bug of some kind?

Is there a way to avoid this error while still using the recommended with statement?


Solution

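  • You could use a generator function that keeps the file open. Because the with block sits inside the generator body, the file stays open while chunks are being yielded and is closed once the generator is exhausted or closed. The chunked branch needs its own function, since any function containing yield is itself a generator. See the following example: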

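    def lines_format(lines):
        # Wrap each stripped line in asterisks and join them into one page.
        return "\n".join(f"*{line.strip()}*" for line in lines)

    def chunk_gen(file, chunksize):
        # The with block lives inside the generator, so the file stays open
        # between yields and is closed once iteration ends.
        with open(file, mode='r') as f:
            while True:
                # readlines(hint) treats chunksize as an approximate size
                # in characters, not as a number of lines.
                lines = f.readlines(chunksize)
                if not lines:
                    break
                yield lines_format(lines)

    def get_formatted_pages(file, chunksize=0):
        if chunksize > 0:
            return chunk_gen(file, chunksize)
        else:
            with open(file, mode='r') as f:
                lines = f.readlines()
                return [lines_format(lines)]

    # Write a three-line sample file. In text mode, use "\n" rather than
    # os.linesep, which would be translated a second time on Windows.
    with open("abc.txt", mode='w') as f:
        f.write("\n".join('abc'))

    # Without chunksize: a single page containing every line.
    pages = get_formatted_pages("abc.txt")
    for i, page in enumerate(pages, start=1):
        print(f"Page {i}")
        print(page)

    # With chunksize: pages are produced lazily by the generator.
    pages = get_formatted_pages("abc.txt", chunksize=2)
    for i, page in enumerate(pages, start=1):
        print(f"Page {i}")
        print(page)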

    Edit: In your pandas.read_csv use case (pandas 1.2 or later, where the chunked reader can be used as a context manager), this would look like:

    import pandas as pd

    # Write a small sample file to read back in chunks.
    df = pd.DataFrame({'char': list('abc'), "num": range(3)})
    df.to_csv('abc.csv')

    def format_df(df):
        # do something
        df['char'] = df['char'].str.capitalize()
        return df

    def gen_chunk(file, chunksize):
        # The reader stays open while chunks are yielded and is closed
        # when the generator finishes.
        with pd.read_csv(file, chunksize=chunksize, index_col=0) as reader:
            for chunk in reader:
                yield format_df(chunk)

    def get_formatted_pages(file, chunksize=0):
        if chunksize > 0:
            return gen_chunk(file, chunksize)
        else:
            return [format_df(pd.read_csv(file, index_col=0))]

    # Consuming the generator yields two formatted chunks (2 rows, then 1).
    list(get_formatted_pages('abc.csv', chunksize=2))
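
One point worth checking with this pattern is when the resource is actually released. The following is a minimal sketch, not pandas code: `tracked_resource` and `gen_items` are hypothetical names, and the context manager only records events so that the cleanup order can be inspected.

```python
from contextlib import contextmanager

events = []

@contextmanager
def tracked_resource(items):
    # Hypothetical stand-in for a file or a chunked reader.
    events.append("open")
    try:
        yield iter(items)
    finally:
        events.append("close")

def gen_items(items):
    # Same structure as the generator above: the with block wraps the loop.
    with tracked_resource(items) as reader:
        for item in reader:
            yield item

# Exhausting the generator leaves the with block normally.
full = list(gen_items([1, 2, 3]))

# Stopping early: close() raises GeneratorExit at the paused yield,
# which also runs the with block's cleanup.
partial = gen_items([1, 2, 3])
first = next(partial)
partial.close()

print(events)
```

In both cases the resource is opened only when iteration starts and is closed when the generator finishes or is closed. A generator that is abandoned without being exhausted or closed keeps the resource open until it is garbage collected, so callers that stop early may want to call `close()` explicitly.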