Chain iterables into a new object but keep their class?


I am working with a cursor object in Python, which is basically an iterator. I actually have two of them and need to chain them into one. These objects have some methods that I need to use later in the program. The problem is that itertools.chain (which basically does the job) returns a chain object, so I no longer have access to the cursor methods.

Is there a way to keep the original class, so that the result is a new object (chained from the two) that still has all the original methods?
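To make the problem concrete, here is a minimal sketch. Cursor is a hypothetical stand-in for a real database cursor (an iterator with one extra method, close). itertools.chain yields the items of both objects, but the result is an itertools.chain instance that has none of the cursor's methods.

```python
import itertools


class Cursor:
    """Hypothetical stand-in for a database cursor: an iterator with an extra method."""

    def __init__(self, rows):
        self._rows = iter(rows)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._rows)

    def close(self):
        print("cursor closed")


first = Cursor([1, 2])
second = Cursor([3])

chained = itertools.chain(first, second)
print(type(chained).__name__)     # chain
print(hasattr(chained, "close"))  # False: the cursor methods are not reachable
print(list(chained))              # [1, 2, 3]
```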


Solution

  • You can write the chaining logic yourself. Let's say you have two iterators of different classes, Foo and Bar:

    class Foo:
        def __init__(self, n):
            self._iter = (i * i for i in range(n))
        def __iter__(self):
            # An iterator must return itself from __iter__ (needed for for-loops, list(), etc.)
            return self
        def __next__(self):
            return next(self._iter)
        def foo_method(self):
            print("Yay, method in Foo was called")
    
    
    class Bar:
        def __init__(self, n):
            self._iter = (i / 2 for i in range(n))
        def __iter__(self):
            return self
        def __next__(self):
            return next(self._iter)
        def bar_method(self):
            print("Yay, method in Bar was called")
    

    Foo iterators have .foo_method() and Bar iterators have .bar_method().

    Now let's chain them together:

    class Chain:
        def __init__(self, *iters):
            self._cursor = 0
            self._iters = iters
        def __iter__(self):
            return self
        def __next__(self):
            """
            Chain iterators together.
            """
            # Loop instead of recursing, so a long run of exhausted
            # iterators cannot hit the recursion limit.
            while self._cursor < len(self._iters):
                try:
                    return next(self._iters[self._cursor])
                except StopIteration:
                    self._cursor += 1
            raise StopIteration
        def __getattr__(self, name):
            """
            Pass everything unknown to the current iterator in the chain.
            """
            # Raise AttributeError (not ValueError) so that hasattr() and
            # getattr(obj, name, default) keep working as expected.
            if self._cursor == len(self._iters):
                raise AttributeError(f"{name!r}: no current iterator")
            return getattr(self._iters[self._cursor], name)
    

    Now if you do something like

    foo = Foo(3)
    bar = Bar(3)
    
    chain = Chain(foo, bar)
    
    print(next(chain))
    chain.foo_method()
    print(next(chain))
    print(next(chain))
    print(next(chain))
    print(next(chain))
    chain.bar_method()
    

    the output will be

    0
    Yay, method in Foo was called
    1
    4
    0.0
    0.5
    Yay, method in Bar was called
    

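    This does not preserve the class of the iterators (isinstance(chain, Foo) is False), but it does let you call any method of the "current" iterator in the chain: the one that produced the most recently returned item, or the first iterator before anything has been read. The chain only moves on to the next iterator when a call to next() finds the current one exhausted, so right after reading the last item of foo, chain.foo_method() still reaches foo.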
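The same idea can be written more compactly with a generator that records which source it is currently reading from. This is a sketch under stated assumptions, not part of the original answer: the class name CursorChain is made up, and sqlite3 cursors over an in-memory database are used only because sqlite3 ships with Python and its cursors are real iterators with extra attributes such as description. Unknown attributes are forwarded to the cursor being read, with the same "current" semantics as Chain above.

```python
import sqlite3


class CursorChain:
    """Hypothetical helper: chain iterators, forward unknown attributes to the current one."""

    def __init__(self, *iters):
        self._iters = iters
        self._current = iters[0] if iters else None
        self._gen = self._walk()

    def _walk(self):
        for it in self._iters:
            self._current = it  # remember which source is being read
            yield from it

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._gen)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup on CursorChain fails.
        if self._current is None:
            raise AttributeError(name)
        return getattr(self._current, name)


# sqlite3 is just a convenient stand-in for "a database cursor" here.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (x INTEGER)")
conn.executemany("INSERT INTO t VALUES (?)", [(1,), (2,), (3,)])

first = conn.execute("SELECT x AS a FROM t WHERE x <= 2 ORDER BY x")
second = conn.execute("SELECT x AS b FROM t WHERE x > 2 ORDER BY x")

chain = CursorChain(first, second)
print(chain.description[0][0])  # a  (column name from the first cursor)
print(next(chain))              # (1,)
print(next(chain))              # (2,)
print(next(chain))              # (3,)
print(chain.description[0][0])  # b  (now forwarded to the second cursor)
```

As with Chain, isinstance(chain, sqlite3.Cursor) is still False; the wrapper only forwards attribute access. Methods that themselves consume rows, such as fetchone(), would read from the underlying cursor behind the chain's back, so only non-consuming attributes are safe to use this way.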