Tags: python, stream, python-itertools, chunking

Python: is there a library function for chunking an input stream?


I want to chunk an input stream for batch processing. Given an input list or generator,

x_in = [1, 2, 3, 4, 5, 6 ...]

I want a function that will return chunks of that input. Say, if chunk_size=4, then,

x_chunked = [[1, 2, 3, 4], [5, 6, ...], ...]

This is something I do over and over, and I was wondering whether there is a more standard way than writing it myself. Am I missing something in itertools? (One could solve the problem with enumerate and groupby, but that feels clunky.) In case anyone wants to see an implementation, here it is:

def chunk_input_stream(input_stream, chunk_size):
    """partition a generator in a streaming fashion"""
    assert chunk_size >= 1
    accumulator = []
    for x in input_stream:
        accumulator.append(x)
        if len(accumulator) == chunk_size:
            yield accumulator
            accumulator = []
    if accumulator:
        yield accumulator
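
For comparison, the enumerate-and-groupby approach mentioned above could look roughly like the sketch below. The name chunk_with_groupby is made up for illustration; the idea is to group items by their index divided by chunk_size.

```python
from itertools import groupby


def chunk_with_groupby(input_stream, chunk_size):
    """Chunk by grouping items on index // chunk_size (illustrative helper)."""
    indexed = enumerate(input_stream)
    for _, group in groupby(indexed, key=lambda pair: pair[0] // chunk_size):
        # Drop the index that enumerate() attached to each item.
        yield [item for _, item in group]
```

It stays lazy, since groupby only pulls items as each chunk is consumed, but the index bookkeeping is arguably what makes it feel clunky.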

Edit

Inspired by kreativitea's answer, here's a solution with islice, which is straightforward and doesn't require post-filtering:

from itertools import islice

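def chunk_input_stream(input_stream, chunk_size):
    """Yield lists of up to chunk_size items.

    input_stream must be an iterator (e.g. a generator), not a list.
    """
    while True:
        # islice consumes the next chunk_size items from the iterator
        chunk = list(islice(input_stream, chunk_size))
        if chunk:
            yield chunk
        else:
            return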

# test it with list(chunk_input_stream(iter([1, 2, 3, 4]), 3))
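
The iter() in that test matters: islice on a plain list starts from the beginning on every call, so the loop would never finish. A minimal sketch of a variant that accepts any iterable, assuming an extra iter() call at the top is acceptable (the name chunk_any_iterable is hypothetical):

```python
from itertools import islice


def chunk_any_iterable(iterable, chunk_size):
    """Like chunk_input_stream, but also safe for lists (illustrative name)."""
    # iter() returns the same object for an iterator and a fresh
    # iterator for a list, so state is kept between islice calls.
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            return
        yield chunk
```

With this version, list(chunk_any_iterable([1, 2, 3, 4], 3)) works without wrapping the list first.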

Solution

  • [Updated version thanks to the OP: I've been throwing yield from at everything in sight since I upgraded and it didn't even occur to me that I didn't need it here.]

    Oh, what the heck:

    from itertools import takewhile, islice, count
    
    def chunk(stream, size):
        return takewhile(bool, (list(islice(stream, size)) for _ in count()))
    

    which gives:

    >>> list(chunk((i for i in range(3)), 3))
    [[0, 1, 2]]
    >>> list(chunk((i for i in range(6)), 3))
    [[0, 1, 2], [3, 4, 5]]
    >>> list(chunk((i for i in range(8)), 3))
    [[0, 1, 2], [3, 4, 5], [6, 7]]
    

    Warning: the above suffers the same problem as the OP's chunk_input_stream if the input is a list: islice starts from the beginning of the list on every call, so the first chunk repeats forever. You could get around this with an extra iter() wrap, but that's less pretty. Conceptually, using repeat or cycle might make more sense than count(), but I was character-counting for some reason. :^)

    [FTR: no, I'm still not entirely serious about this, but hey-- it's a Monday.]
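
For what it's worth, the two tweaks mentioned in the warning (an iter() wrap, and repeat in place of count()) might be combined along these lines. This is only a sketch, and chunk_repeat is an illustrative name rather than anything from the answer:

```python
from itertools import islice, repeat, takewhile


def chunk_repeat(stream, size):
    """takewhile-based chunker with an iter() wrap (illustrative name)."""
    iterator = iter(stream)  # lets lists work as well as generators
    # repeat(None) is an endless source of "try again" ticks; takewhile
    # stops at the first empty chunk, i.e. once the input is exhausted.
    return takewhile(bool, (list(islice(iterator, size)) for _ in repeat(None)))
```

It keeps the one-expression flavour of the original while no longer looping forever on list input.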