
zip iterators asserting for equal length in python


I am looking for a nice way to zip several iterables raising an exception if the lengths of the iterables are not equal.

In the case where the iterables are lists or otherwise support len(), this solution is clean and easy:

def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)

However, if it1 and it2 are generators, the previous function fails because generators have no length: TypeError: object of type 'generator' has no len().

I imagine the itertools module offers a simple way to implement that, but so far I have not been able to find it. I have come up with this home-made solution:

def zip_equal(it1, it2):
    it1, it2 = iter(it1), iter(it2)  # also accept lists and other non-iterator iterables
    exhausted = False
    while True:
        try:
            el1 = next(it1)
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return

The solution can be tested with the following code:

it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it1, it2))           # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it2, it1))           # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd'])  # it1 has length 4
it2 = (x for x in [0, 1, 2, 3])          # it2 has length 4
list(zip_equal(it1, it2))                # like zip (or izip in python2)
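
As written, the first list(...) call above raises and stops the script before the other two cases run. The sketch below runs all three cases in one go, using a compact two-iterator variant that relies on next()'s default argument instead of catching StopIteration; the names zip_equal_next and check are hypothetical, chosen only for this example.

```python
def zip_equal_next(it1, it2):
    """Hypothetical compact variant: zip two iterables, raise if lengths differ."""
    it1, it2 = iter(it1), iter(it2)
    sentinel = object()  # unique marker that next() returns once an iterator is done
    while True:
        el1 = next(it1, sentinel)
        el2 = next(it2, sentinel)
        if el1 is sentinel or el2 is sentinel:
            # Lengths are equal only if both ran out on the same step.
            if el1 is not el2:
                raise ValueError("it1 and it2 have different lengths")
            return
        yield (el1, el2)


def check(a, b):
    """Return the zipped list, or the string 'ValueError' if the lengths differ."""
    try:
        return list(zip_equal_next(a, b))
    except ValueError:
        return 'ValueError'


results = [
    check((x for x in ['a', 'b', 'c']), (x for x in [0, 1, 2, 3])),       # 3 vs 4
    check((x for x in [0, 1, 2, 3]), (x for x in ['a', 'b', 'c'])),       # 4 vs 3
    check((x for x in ['a', 'b', 'c', 'd']), (x for x in [0, 1, 2, 3])),  # 4 vs 4
]
for r in results:
    print(r)
```

Comparing against a private object() with is, rather than ==, means no element value can be mistaken for the end-of-input marker.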

Am I overlooking any alternative solution? Is there a simpler implementation of my zip_equal function?

Update:

  • Python 3.10 or newer: the builtin zip accepts strict=True; see Asocia's answer
  • Thorough performance benchmarking and the best-performing solution on Python < 3.10: Stefan's answer
  • Simple answer without external dependencies: Martijn Pieters' answer (please check the comments for a bugfix in some corner cases)
  • More complex than Martijn's, but with better performance: cjerdonek's answer
  • If you don't mind a package dependency, see pylang's answer

Solution

  • A new solution, based on cjerdonek's but much faster, plus a benchmark. Benchmark first; my solution is the green line. Note that the "total size" is the same in all cases, two million values, and the x-axis is the number of iterables: from 1 iterable with two million values, then 2 iterables with a million values each, all the way up to 100,000 iterables with 20 values each.

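    [Figure: benchmark plot "zip_equal(m iterables with 2,000,000/m values each)"; x-axis: number of zipped iterables; y-axis: time for complete iteration in milliseconds]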

    The black line is Python's plain zip. I used Python 3.8 here, so zip doesn't do this question's task of checking for equal lengths; I include it only as a reference for the maximum speed one can hope for. You can see my solution is pretty close.
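
    The per-iterable length comes from integer division in the benchmark loop (test(m, total_elements // m)), so the total is exactly two million whenever m divides it. Among the tested values the only exception is m = 3, which gets 666,666 values per iterable (1,999,998 in total). A quick check:

```python
total_elements = 2 * 10**6
ms = (1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000)

# Values per iterable, computed the same way as the benchmark's test(m, total_elements // m).
sizes = {m: total_elements // m for m in ms}
for m, n in sizes.items():
    print(f'{m:>6} iterables x {n:>9,} values = {m * n:>9,} total')
```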

    For perhaps the most common case, zipping two iterables, mine is almost three times as fast as the previously fastest solution by cjerdonek, and not much slower than zip. Times in milliseconds as text:

              number of iterables     1     2     3     4     5    10   100  1000 10000 50000 100000
    -----------------------------------------------------------------------------------------------
           more_itertools__pylang 209.3 132.1 105.8  93.7  87.4  74.4  54.3  51.9  53.9  66.9  84.5
       fillvalue__Martijn_Pieters 159.1 101.5  85.6  74.0  68.8  59.0  44.1  43.0  44.9  56.9  72.0
         chain_raising__cjerdonek  58.5  35.1  26.3  21.9  19.7  16.6  10.4  12.7  34.4 115.2 223.2
         ziptail__Stefan_Pochmann  10.3  12.4  10.4   9.2   8.7   7.8   6.7   6.8   9.4  22.6  37.8
                              zip  10.3   8.5   7.8   7.4   7.4   7.1   6.4   6.8   9.0  19.4  32.3
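
    The speed claims can be checked against the 2-iterables column of the table (values copied from it, in milliseconds):

```python
# Times in ms for zipping 2 iterables, taken from the table above.
cjerdonek = 35.1
ziptail = 12.4
plain_zip = 8.5

speedup_vs_cjerdonek = cjerdonek / ziptail  # roughly 2.8x faster
overhead_vs_zip = ziptail / plain_zip       # roughly 1.5x the time of plain zip
print(f'{speedup_vs_cjerdonek:.2f}x faster than cjerdonek, '
      f'{overhead_vs_zip:.2f}x the time of plain zip')
```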
    

    My code:

    from itertools import chain
    
    def zip_equal(*iterables):
    
        # For trivial cases, use pure zip.
        if len(iterables) < 2:
            return zip(*iterables)
    
        # Tail for the first iterable
        first_stopped = False
        def first_tail():
            nonlocal first_stopped 
            first_stopped = True
            return
            yield
    
        # Tail for the zip
        def zip_tail():
            if not first_stopped:
                raise ValueError('zip_equal: first iterable is longer')
            for _ in chain.from_iterable(rest):
                raise ValueError('zip_equal: first iterable is shorter')
                yield
    
        # Put the pieces together
        iterables = iter(iterables)
        first = chain(next(iterables), first_tail())
        rest = list(map(iter, iterables))
        return chain(zip(first, *rest), zip_tail())
    

    The basic idea is to let zip(*iterables) do all the work and then, once it has stopped because some iterable was exhausted, check whether all iterables were equally long. They were if and only if:

    1. zip stopped because the first iterable had no more elements (i.e., no other iterable is shorter).
    2. None of the other iterables has any further elements (i.e., no other iterable is longer).
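
    Both criteria rely on zip pulling from its arguments left to right and stopping at the first one that is exhausted. A small demonstration with plain builtins: when the first iterable runs out, the others are not advanced, so a leftover element is still there to detect; when a later iterable runs out first, zip has already taken one element from the earlier ones and drops it. That is why, in the solution, first_stopped stays False whenever zip stopped on some iterable other than the first.

```python
# Case 1: the first iterable is shorter. zip stops as soon as `a` is exhausted,
# before asking `b` for anything, so b's extra element is still available.
a = iter([1, 2])
b = iter(['x', 'y', 'z'])
pairs1 = list(zip(a, b))
leftover_b = next(b, None)

# Case 2: the second iterable is shorter. zip takes 3 from `c`, then finds `d`
# exhausted and stops; the 3 is consumed and discarded.
c = iter([1, 2, 3])
d = iter(['x', 'y'])
pairs2 = list(zip(c, d))
leftover_c = next(c, None)

print(pairs1, leftover_b)  # [(1, 'x'), (2, 'y')] z
print(pairs2, leftover_c)  # [(1, 'x'), (2, 'y')] None
```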

    How I check these criteria:

    • Since I need to check these criteria after zip has ended, I can't simply return the zip object. Instead, I chain an empty zip_tail iterator behind it that does the checking.
    • To support checking the first criterion, I chain an empty first_tail iterator behind the first iterable whose sole job is to record that the first iterable's iteration stopped (i.e., it was asked for another element and didn't have one, so the first_tail iterator was asked for one instead).
    • To support checking the second criterion, I fetch iterators for all the other iterables and keep them in a list before I give them to zip.
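
    The first_tail trick in isolation: a generator that yields nothing but records, as a side effect, that it was asked for a value. Chaining it after an iterable tells you exactly when that iterable was exhausted. The helper name track_exhaustion is hypothetical and used only for this sketch.

```python
from itertools import chain


def track_exhaustion(iterable):
    """Hypothetical helper: return (iterator, state); state['stopped'] becomes
    True once something asks the iterator for an element past the end."""
    state = {'stopped': False}

    def tail():
        state['stopped'] = True  # runs only after `iterable` is exhausted
        return
        yield  # unreachable, but makes tail() a generator function

    return chain(iterable, tail()), state


it, state = track_exhaustion([1, 2])
first_two = [next(it), next(it)]
before = state['stopped']       # False: nobody has asked for a third element yet
after_value = next(it, 'done')  # chain moves on to tail(), which sets the flag
after = state['stopped']        # True
print(first_two, before, after_value, after)
```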

    Side note: more-itertools uses pretty much the same method as Martijn's, but does proper "is" checks instead of Martijn's not-quite-correct "sentinel in combo", which also treats an element that merely compares equal (==) to the sentinel as a match. That's probably the main reason it's slower.
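
    Why "sentinel in combo" is not quite correct: the in operator checks identity or equality, so an element whose __eq__ claims equality with everything is mistaken for the sentinel. The AlwaysEqual class below is a hypothetical, deliberately pathological element type; the two generators mirror the check styles of Martijn's version and the more-itertools code in the benchmark.

```python
from itertools import zip_longest


class AlwaysEqual:
    """Hypothetical pathological element type: compares equal to everything."""
    def __eq__(self, other):
        return True


def zip_equal_in(*iterables):
    # Sentinel test with `in` (identity OR ==), the style used in Martijn's version.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo


def zip_equal_is(*iterables):
    # Sentinel test with `is` only, the style used in the more-itertools code.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if any(val is sentinel for val in combo):
            raise ValueError('Iterables have different lengths')
        yield combo


items = [AlwaysEqual(), AlwaysEqual()]  # length 2, same as [1, 2]

try:
    list(zip_equal_in(items, [1, 2]))
    in_raised = False
except ValueError:
    in_raised = True  # false positive: AlwaysEqual() == sentinel is True

pairs = list(zip_equal_is(items, [1, 2]))  # no error, two pairs
print(in_raised, len(pairs))
```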

    Benchmark code:

    import timeit
    import itertools
    import warnings
    from itertools import repeat, chain, zip_longest
    from collections import deque
    from sys import hexversion, maxsize
    
    #-----------------------------------------------------------------------------
    # Solution by Martijn Pieters
    #-----------------------------------------------------------------------------
    
    def zip_equal__fillvalue__Martijn_Pieters(*iterables):
        sentinel = object()
        for combo in zip_longest(*iterables, fillvalue=sentinel):
            if sentinel in combo:
                raise ValueError('Iterables have different lengths')
            yield combo
    
    #-----------------------------------------------------------------------------
    # Solution by pylang
    #-----------------------------------------------------------------------------
    
    def zip_equal__more_itertools__pylang(*iterables):
        return more_itertools__zip_equal(*iterables)
    
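    class UnequalIterablesError(ValueError):
        # Minimal stand-in for more_itertools' exception so this copy runs standalone.
        def __init__(self, details=None):
            msg = 'Iterables have different lengths'
            if details is not None:
                msg += ': index 0 has length {}; index {} has length {}'.format(*details)
            super().__init__(msg)
    
    _marker = object()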
    
    def _zip_equal_generator(iterables):
        for combo in zip_longest(*iterables, fillvalue=_marker):
            for val in combo:
                if val is _marker:
                    raise UnequalIterablesError()
            yield combo
    
    def more_itertools__zip_equal(*iterables):
        """``zip`` the input *iterables* together, but raise
        ``UnequalIterablesError`` if they aren't all the same length.
    
            >>> it_1 = range(3)
            >>> it_2 = iter('abc')
            >>> list(zip_equal(it_1, it_2))
            [(0, 'a'), (1, 'b'), (2, 'c')]
    
            >>> it_1 = range(3)
            >>> it_2 = iter('abcd')
            >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
            Traceback (most recent call last):
            ...
            more_itertools.more.UnequalIterablesError: Iterables have different
            lengths
    
        """
        if hexversion >= 0x30A00A6:
            warnings.warn(
                (
                    'zip_equal will be removed in a future version of '
                    'more-itertools. Use the builtin zip function with '
                    'strict=True instead.'
                ),
                DeprecationWarning,
            )
        # Check whether the iterables are all the same size.
        try:
            first_size = len(iterables[0])
            for i, it in enumerate(iterables[1:], 1):
                size = len(it)
                if size != first_size:
                    break
            else:
                # If we didn't break out, we can use the built-in zip.
                return zip(*iterables)
    
            # If we did break out, there was a mismatch.
            raise UnequalIterablesError(details=(first_size, i, size))
        # If any one of the iterables didn't have a length, start reading
        # them until one runs out.
        except TypeError:
            return _zip_equal_generator(iterables)
    
    #-----------------------------------------------------------------------------
    # Solution by cjerdonek
    #-----------------------------------------------------------------------------
    
    class ExhaustedError(Exception):
        def __init__(self, index):
            """The index is the 0-based index of the exhausted iterable."""
            self.index = index
    
    def raising_iter(i):
        """Return an iterator that raises an ExhaustedError."""
        raise ExhaustedError(i)
        yield
    
    def terminate_iter(i, iterable):
        """Return an iterator that raises an ExhaustedError at the end."""
        return itertools.chain(iterable, raising_iter(i))
    
    def zip_equal__chain_raising__cjerdonek(*iterables):
        iterators = [terminate_iter(*args) for args in enumerate(iterables)]
        try:
            yield from zip(*iterators)
        except ExhaustedError as exc:
            index = exc.index
            if index > 0:
                raise RuntimeError('iterable {} exhausted first'.format(index)) from None
            # Check that all other iterators are also exhausted.
            for i, iterator in enumerate(iterators[1:], start=1):
                try:
                    next(iterator)
                except ExhaustedError:
                    pass
                else:
                    raise RuntimeError('iterable {} is longer'.format(i)) from None
                
    #-----------------------------------------------------------------------------
    # Solution by Stefan Pochmann
    #-----------------------------------------------------------------------------
    
    def zip_equal__ziptail__Stefan_Pochmann(*iterables):
    
        # For trivial cases, use pure zip.
        if len(iterables) < 2:
            return zip(*iterables)
    
        # Tail for the first iterable
        first_stopped = False
        def first_tail():
            nonlocal first_stopped 
            first_stopped = True
            return
            yield
    
        # Tail for the zip
        def zip_tail():
            if not first_stopped:
                raise ValueError(f'zip_equal: first iterable is longer')
            for _ in chain.from_iterable(rest):
                raise ValueError(f'zip_equal: first iterable is shorter')
                yield
    
        # Put the pieces together
        iterables = iter(iterables)
        first = chain(next(iterables), first_tail())
        rest = list(map(iter, iterables))
        return chain(zip(first, *rest), zip_tail())
    
    #-----------------------------------------------------------------------------
    # List of solutions to be speedtested
    #-----------------------------------------------------------------------------
    
    solutions = [
        zip_equal__more_itertools__pylang,
        zip_equal__fillvalue__Martijn_Pieters,
        zip_equal__chain_raising__cjerdonek,
        zip_equal__ziptail__Stefan_Pochmann,
        zip,
    ]
    
    def name(solution):
        return solution.__name__[11:] or 'zip'
    
    #-----------------------------------------------------------------------------
    # The speedtest code
    #-----------------------------------------------------------------------------
    
    def test(m, n):
        """Speedtest all solutions with m iterables of n elements each."""
    
        all_times = {solution: [] for solution in solutions}
        def show_title():
            print(f'{m} iterators of length {n:,}:')
        if verbose: show_title()
        def show_times(times, solution):
            print(*('%3d ms ' % t for t in times),
                  name(solution))
            
        for _ in range(3):
            for solution in solutions:
                times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
                times = [round(t * 1e3, 3) for t in times]
                all_times[solution].append(times)
                if verbose: show_times(times, solution)
            if verbose: print()
            
        if verbose:
            print('best by min:')
            show_title()
            for solution in solutions:
                show_times(min(all_times[solution], key=min), solution)
            print('best by max:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=max), solution)
        print()
    
        stats.append((m,
                      [min(all_times[solution], key=min)
                       for solution in solutions]))
    
    #-----------------------------------------------------------------------------
    # Run the speedtest for several numbers of iterables
    #-----------------------------------------------------------------------------
    
    stats = []
    verbose = False
    total_elements = 2 * 10**6
    for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
        test(m, total_elements // m)
    
    #-----------------------------------------------------------------------------
    # Print the speedtest results for use in the plotting script
    #-----------------------------------------------------------------------------
    
    print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
    names = [name(solution) for solution in solutions]
    print(f'{names = }')
    print(f'{stats = }')
    

    Code for plotting/table (also at Replit):

    import matplotlib.pyplot as plt
    
    names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
    stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]
    
    colors = {
        'more_itertools__pylang': 'm',
        'fillvalue__Martijn_Pieters': 'red',
        'chain_raising__cjerdonek': 'gold',
        'ziptail__Stefan_Pochmann': 'lime',
        'zip': 'black',
    }
    
    ns = [n for n, _ in stats]
    print('%29s' % 'number of iterables', *('%5d' % n for n in ns))
    print('-' * 95)
    x = range(len(ns))
    for i, name in enumerate(names):
        ts = [min(tss[i]) for _, tss in stats]
        color = colors[name]
        if color:
            plt.plot(x, ts, '.-', color=color, label=name)
            print('%29s' % name, *('%5.1f' % t for t in ts))
    plt.xticks(x, ns, size=9)
    plt.ylim(0, 133)
    plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
    plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
    plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
    plt.legend(loc='upper center')
    #plt.show()
    plt.savefig('zip_equal_plot.png', dpi=200)