How to wrap a generator with a filter?


I have a series of connected generators and I want to create a filter that can be used to wrap one of them. This filter wrapper should take a generator and a function as parameters. If a data item in the incoming stream does not pass the requirements of the filter, it should be passed downstream to the next generator without going through the wrapped one. Here is a working example that should make it clearer what I am trying to achieve:

import functools

is_less_than_three = lambda x: x < 3

def add_one(numbers):
    print("new generator created")
    for number in numbers:
        yield number + 1

def wrapper(generator1, filter_):
    @functools.wraps(generator1)
    def wrapped(generator2):
        for data in generator2:
            if filter_(data):
                # a brand-new generator is created for every passing item
                yield from generator1([data])
            else:
                yield data
    return wrapped

add_one_to_numbers_less_than_three = wrapper(add_one, is_less_than_three)
answers = add_one_to_numbers_less_than_three(range(6))
for answer in answers:
    print(answer)

#new generator created
#1
#new generator created
#2
#new generator created
#3
#3
#4
#5

The problem with this is that it requires creating a new generator for each data item. There must be a better way? I have also tried using itertools.tee and splitting the generator, but this causes memory problems when the generators yield values at different rates (they do). How can I accomplish what the above code does without re-creating generators and without causing memory problems?
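
For reference, a minimal sketch of the tee problem (the size is made up, but the buffering behavior is how itertools.tee works: it has to remember every item one branch has consumed until the other branch catches up):

import itertools

a, b = itertools.tee(range(10**6))
for value in a:   # draining only one branch forces tee to buffer every item
    pass          # until b catches up, so branches consumed at very
                  # different rates can grow memory without bound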

edited to add background information below

As input I will receive large video streams. The video streams may or may not end (could be a webcam). Users are able to choose which image processing steps are carried out on the video frames, thus the order and number of functions will change. Subsequently, the functions should be able to take each other's outputs as inputs.

I have accomplished this by using a series of generators. The input:output ratio of the generators/functions is variable - it could be 1:n, 1:1, or n:1 (for example, extracting several objects (subimages) from an image to be processed separately).
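
To make that concrete, here is a minimal toy sketch of such a chain, using numbers in place of frames (the stage names are made up for illustration):

def increment(numbers):        # a 1:1 stage
    for n in numbers:
        yield n + 1

def duplicate(numbers):        # a 1:n stage: one input, two outputs
    for n in numbers:
        yield n
        yield n

def pairwise_sum(numbers):     # an n:1 stage: two inputs, one output
    numbers = iter(numbers)
    for a in numbers:
        yield a + next(numbers, 0)

# stages compose by passing one generator to the next
pipeline = pairwise_sum(duplicate(increment(range(3))))
print(list(pipeline))          # [2, 4, 6]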

Currently these generators take a few parameters that are repeated among them (not DRY) and I am trying to decrease the number of parameters by refactoring them into separate generators or wrappers. One of the more difficult ones is a filter on the data stream to determine whether or not a function should be applied to the frame (the function could be cpu-intensive and not needed on all frames).

The number of parameters makes the usage of the function more difficult for the user to understand. It also makes it more difficult for me in that whenever I want to make a change to one of the common parameters, I have to edit it for all functions.

edit2 renamed function to generator in the example code to make it clearer

edit3 the solution. Thank you @Blckknght. This can be solved by creating an infinite iterator that passes the value of a local variable to the generator. I modified my example slightly, changing add_one from a 1:1 generator into a 1:n generator, to show how this solution can also work for 1:n generators (note that the wrapper has to know when the wrapped generator will yield more than one value, which it does here via the data == 0 check).

import functools

is_less_than_three = lambda x: x < 3

def add_one(numbers):
    # despite the name, this version no longer adds one; it is a 1:n
    # generator that yields two values when the input is 0
    print("new generator created")
    for number in numbers:
        if number == 0:
            yield number - 1
            yield number
        else:
            yield number

def wrapper(generator1, filter_):
    @functools.wraps(generator1)
    def wrapped(generator2):
        # iter(callable, sentinel) yields the current value of the local
        # variable data forever: object() is a fresh sentinel the lambda
        # can never return, so the iterator never stops
        local_variable_passer = generator1(iter(lambda: data, object()))
        for data in generator2:
            if filter_(data):
                next_data = next(local_variable_passer)
                if data == 0:
                    yield next_data
                    next_data = next(local_variable_passer)
                    yield next_data
                else:
                    yield next_data
            else:
                yield data
    return wrapped

add_one_to_numbers_less_than_three = wrapper(add_one, is_less_than_three)
answers = add_one_to_numbers_less_than_three(range(6))
for answer in answers:
    print(answer)

#new generator created
#-1
#0
#1
#2
#3
#4
#5

Solution

  • As I understand your problem, you have a stream of video frames, and you're trying to create a pipeline of processing functions that modify the stream. Different processing functions might change the number of frames, so a single input frame could result in multiple output frames, or multiple input frames could be consumed before a single output frame is produced. Some functions might be 1:1, but that's not something you can count on.

    Your current implementation uses generator functions for all the processing. The output function iterates on the chain, and each processing step in the pipeline requests frames from the one before it using iteration.

    The function you're trying to write right now is a sort of selective bypass. You want some frames (those meeting some condition) to be passed into an already existing generator function, while other frames skip the processing and go directly to the output. Unfortunately, that's probably not possible to do in full generality with Python generators; the iteration protocol just isn't sophisticated enough to support it.

    First off, it is possible to do this for 1:1 with generators, but you can't easily generalize to n:1 or 1:n cases. Here's what it might look like for 1:1:

    def selective_processing_1to1(processing_func, condition, input_iterable):
        # iter(callable, sentinel) makes an infinite iterator that always
        # yields the current value of the local variable input_value
        processing_iterator = processing_func(iter(lambda: input_value, object()))
        for input_value in input_iterable:
            if condition(input_value):
                yield next(processing_iterator)
            else:
                yield input_value
    

    There's a lot of work being done in the processing_iterator creation step. By using the two-argument form of iter with a lambda function and a sentinel object (which the lambda will never return), I'm creating an infinite iterator that always yields the current value of the local variable input_value. Then I pass that iterator to processing_func. I can selectively call next on the resulting generator if I want to apply the processing to the current value, or I can just yield the value myself without processing it.
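
    For example, reusing is_less_than_three and the original 1:1 add_one from the question (not the 1:n version from edit3), the wrapped generator is now created exactly once:

    for value in selective_processing_1to1(add_one, is_less_than_three, range(6)):
        print(value)

    # new generator created
    # 1
    # 2
    # 3
    # 3
    # 4
    # 5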

    But because this only works on one frame at a time, it won't do for n:1 or 1:n filters (and I don't even want to think about m:n kinds of scenarios).

    A "peekable" iterator that lets you see what the next value is going to be before you iterate onto it might let you support a limited form of selective filtering for n:1 processes (that is, where a possibly-variable n input frames go into one output frame). The limitation is that you can only do the selective filtering on the first of the n frames that is going to be consumed by the processing, the others will get taken without you getting a chance to check them first. Maybe that's good enough?

    Anyway, here's what that looks like:

    _sentinel = object()
    class PeekableIterator:
        def __init__(self, input_iterable):
            self.iterator = iter(input_iterable)
            self.next_value = next(self.iterator, _sentinel)

        def __iter__(self):
            return self

        def __next__(self):
            # compare against the sentinel by identity; values such as numpy
            # arrays can't safely be compared with != anyway
            if self.next_value is not _sentinel:
                return_value = self.next_value
                self.next_value = next(self.iterator, _sentinel)
                return return_value
            raise StopIteration

        def peek(self):                 # this is not part of the iteration protocol!
            if self.next_value is not _sentinel:
                return self.next_value
            raise ValueError("input exhausted")
    
    def selective_processing_Nto1(processing_func, condition, input_iterable):
        peekable = PeekableIterator(input_iterable)
        processing_iter = processing_func(peekable)
        while True:
            try:
                value = peekable.peek()    # look ahead without consuming
            except ValueError:             # the input is exhausted
                return
            try:
                yield next(processing_iter) if condition(value) else next(peekable)
            except StopIteration:
                return
    
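
    For example, with a made-up 2:1 processing function (sum_pairs is invented here just for illustration), you can see the limitation in action: the 3 fails the condition but is consumed by the processing anyway, because it is the second frame of a pair:

    def sum_pairs(numbers):                # a 2:1 processing function (made up)
        for a in numbers:
            yield a + next(numbers, 0)

    for value in selective_processing_Nto1(sum_pairs, is_less_than_three, range(6)):
        print(value)

    # 1    (0 + 1)
    # 5    (2 + 3: 3 fails the condition but is consumed anyway)
    # 4
    # 5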

    This is about as good as we can practically do when the processing function is a generator. If we wanted to do more, such as supporting 1:n processing, we'd need some way to know how large n was going to be, so we could pull that many values before deciding whether to pass on the next input value. While you could write a custom class for the processing that reports this, it's probably less convenient than just calling the processing function repeatedly, as you do in the question.
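
    Purely as a sketch of that last idea (nothing below is a standard protocol; every name is invented for illustration): a processing step could declare how many frames it is about to consume, so a driver can check the condition on the whole batch before committing to it:

    import itertools

    class SumPairsStep:                    # hypothetical processing step
        def frames_needed(self):
            return 2                       # always consumes two frames per output

        def process(self, frames):
            return sum(frames)             # one output frame from n inputs

    def selective_driver(step, condition, input_iterable):
        it = iter(input_iterable)
        while True:
            batch = list(itertools.islice(it, step.frames_needed()))
            if not batch:
                return
            if all(condition(v) for v in batch):
                yield step.process(batch)  # process the whole batch
            else:
                yield from batch           # bypass: emit the frames unprocessed

    # list(selective_driver(SumPairsStep(), is_less_than_three, range(6)))
    # -> [1, 2, 3, 4, 5]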