I am looking for a nice way to zip several iterables, raising an exception if the lengths of the iterables are not equal.
In the case where the iterables are lists or have a len method, this solution is clean and easy:
def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)
However, if it1 and it2 are generators, the previous function fails because the length is not defined: TypeError: object of type 'generator' has no len().
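The failure is easy to reproduce. A minimal sketch (the names gen and message are only for illustration) showing that a generator has no length, which is why the len-based check cannot be used here:

```python
# A generator expression produces values lazily and does not define __len__.
gen = (x for x in ['a', 'b', 'c'])

try:
    len(gen)
except TypeError as exc:
    # Keep the error text for inspection.
    message = str(exc)

print(message)
```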
I imagine the itertools module offers a simple way to implement that, but so far I have not been able to find it. I have come up with this home-made solution:
def zip_equal(it1, it2):
    exhausted = False
    while True:
        try:
            el1 = next(it1)
            if exhausted:  # in a previous iteration it2 was exhausted but it1 still has elements
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return
The solution can be tested with the following code:
it1 = (x for x in ['a', 'b', 'c']) # it1 has length 3
it2 = (x for x in [0, 1, 2, 3]) # it2 has length 4
list(zip_equal(it1, it2)) # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c']) # it1 has length 3
it2 = (x for x in [0, 1, 2, 3]) # it2 has length 4
list(zip_equal(it2, it1)) # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd']) # it1 has length 4
it2 = (x for x in [0, 1, 2, 3]) # it2 has length 4
list(zip_equal(it1, it2)) # like zip (or izip in python2)
Am I overlooking any alternative solution? Is there a simpler implementation of my zip_equal function?
Update:
A new solution, much faster even than cjerdonek's, on which it's based, plus a benchmark. Benchmark first; my solution is green. Note that the "total size" is the same in all cases: two million values. The x-axis is the number of iterables. It goes from 1 iterable with two million values, through 2 iterables with a million values each, all the way up to 100,000 iterables with 20 values each.
The black one is Python's zip. I used Python 3.8 here, so it doesn't do this question's task of checking for equal lengths, but I include it as a reference/limit of the maximum speed one can hope for. You can see my solution is pretty close.
For the perhaps most common case of zipping two iterables, mine's almost three times as fast as the previously fastest solution by cjerdonek, and not much slower than zip. Times (in milliseconds) as text:
number of iterables 1 2 3 4 5 10 100 1000 10000 50000 100000
-----------------------------------------------------------------------------------------------
more_itertools__pylang 209.3 132.1 105.8 93.7 87.4 74.4 54.3 51.9 53.9 66.9 84.5
fillvalue__Martijn_Pieters 159.1 101.5 85.6 74.0 68.8 59.0 44.1 43.0 44.9 56.9 72.0
chain_raising__cjerdonek 58.5 35.1 26.3 21.9 19.7 16.6 10.4 12.7 34.4 115.2 223.2
ziptail__Stefan_Pochmann 10.3 12.4 10.4 9.2 8.7 7.8 6.7 6.8 9.4 22.6 37.8
zip 10.3 8.5 7.8 7.4 7.4 7.1 6.4 6.8 9.0 19.4 32.3
My code (Try it online!):
from itertools import chain

def zip_equal(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())
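The basic idea is to let zip(*iterables) do all the work, and then, after it stopped because some iterable was exhausted, check whether all iterables were equally long. They were if and only if:

1. zip stopped because the first iterable didn't have another element (i.e., no other iterable is shorter).
2. None of the other iterables has any further elements (i.e., no other iterable is longer).

How I check these criteria:

- Since I need to check these criteria after zip ended, I can't return the zip object purely. Instead, I chain an empty zip_tail iterator behind it that does the checking.
- To support checking the first criterion, I chain an empty first_tail iterator behind the first iterable, whose sole job is to log that the first iterable's iteration stopped (i.e., it was asked for another element and it didn't have one, so the first_tail iterator was asked for one instead).
- To support checking the second criterion, I fetch iterators for all the other iterables and keep them in a list before I give them to zip.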
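The first_tail trick can be looked at in isolation. The following is a small sketch, not the code from above: the helper with_exhaustion_flag and its return shape are made up for illustration. It chains an empty generator behind an iterable and records when that generator gets asked for an element.

```python
from itertools import chain

def with_exhaustion_flag(iterable):
    """Hypothetical helper: return (iterator, stopped), where stopped()
    tells whether the iterator was asked for an element after the
    iterable had run out."""
    state = {"stopped": False}

    def tail():
        # Runs only when chain() moves past the exhausted iterable.
        state["stopped"] = True
        return
        yield  # never reached; its presence makes tail() a generator

    return chain(iterable, tail()), lambda: state["stopped"]

# Case 1: the first iterable runs out, so zip asks the tail.
first, first_stopped = with_exhaustion_flag("ab")
pairs_1 = list(zip(first, iter([0, 1, 2])))
flag_1 = first_stopped()

# Case 2: the second iterable is shorter, so zip stops before asking the tail.
first, first_stopped = with_exhaustion_flag("abc")
pairs_2 = list(zip(first, iter([0, 1])))
flag_2 = first_stopped()

print(pairs_1, flag_1)
print(pairs_2, flag_2)
```

Because zip pulls from its arguments left to right, the flag is set only when the first iterable was the one that ended the zipping, which is exactly what the first criterion needs.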
Side note: more-itertools pretty much uses the same method as Martijn's, but does proper "is" checks instead of Martijn's not quite correct "sentinel in combo". That's probably the main reason it's slower.
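To see why a membership test is not quite correct here: the in operator compares with == after an identity check, so an element with an unusual __eq__ can be mistaken for the sentinel. A small sketch, where the class AlwaysEqual is a made-up example and not something from the answers:

```python
class AlwaysEqual:
    """Made-up element type whose == claims equality with anything."""
    def __eq__(self, other):
        return True

sentinel = object()
combo = (AlwaysEqual(), 1)  # a tuple that does NOT contain the sentinel

# Membership test: falls back to ==, so it reports a false positive.
found_by_membership = sentinel in combo

# Identity test: only the sentinel object itself can match.
found_by_identity = any(value is sentinel for value in combo)

print(found_by_membership, found_by_identity)
```

The identity version needs a Python-level loop over each tuple, which fits the observation that it is the slower of the two.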
Benchmark code (Try it online!):
import timeit
import itertools
import warnings  # needed by more_itertools__zip_equal on Python 3.10+
from itertools import repeat, chain, zip_longest
from collections import deque
from sys import hexversion, maxsize
#-----------------------------------------------------------------------------
# Solution by Martijn Pieters
#-----------------------------------------------------------------------------
def zip_equal__fillvalue__Martijn_Pieters(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo
#-----------------------------------------------------------------------------
# Solution by pylang
#-----------------------------------------------------------------------------
def zip_equal__more_itertools__pylang(*iterables):
    return more_itertools__zip_equal(*iterables)

_marker = object()

# Minimal stand-in for the more-itertools exception class, which the
# original listing used without defining it (only hit on a mismatch).
class UnequalIterablesError(ValueError):
    def __init__(self, details=None):
        msg = 'Iterables have different lengths'
        if details is not None:
            msg += ': index 0 has length {}; index {} has length {}'.format(*details)
        super().__init__(msg)

def _zip_equal_generator(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def more_itertools__zip_equal(*iterables):
    """``zip`` the input *iterables* together, but raise
    ``UnequalIterablesError`` if they aren't all the same length.

        >>> it_1 = range(3)
        >>> it_2 = iter('abc')
        >>> list(zip_equal(it_1, it_2))
        [(0, 'a'), (1, 'b'), (2, 'c')]

        >>> it_1 = range(3)
        >>> it_2 = iter('abcd')
        >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        more_itertools.more.UnequalIterablesError: Iterables have different
        lengths

    """
    if hexversion >= 0x30A00A6:
        warnings.warn(
            (
                'zip_equal will be removed in a future version of '
                'more-itertools. Use the builtin zip function with '
                'strict=True instead.'
            ),
            DeprecationWarning,
        )
    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                break
        else:
            # If we didn't break out, we can use the built-in zip.
            return zip(*iterables)

        # If we did break out, there was a mismatch.
        raise UnequalIterablesError(details=(first_size, i, size))
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)
#-----------------------------------------------------------------------------
# Solution by cjerdonek
#-----------------------------------------------------------------------------
class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal__chain_raising__cjerdonek(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None
#-----------------------------------------------------------------------------
# Solution by Stefan Pochmann
#-----------------------------------------------------------------------------
def zip_equal__ziptail__Stefan_Pochmann(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())
#-----------------------------------------------------------------------------
# List of solutions to be speedtested
#-----------------------------------------------------------------------------
solutions = [
    zip_equal__more_itertools__pylang,
    zip_equal__fillvalue__Martijn_Pieters,
    zip_equal__chain_raising__cjerdonek,
    zip_equal__ziptail__Stefan_Pochmann,
    zip,
]

def name(solution):
    return solution.__name__[11:] or 'zip'
#-----------------------------------------------------------------------------
# The speedtest code
#-----------------------------------------------------------------------------
def test(m, n):
    """Speedtest all solutions with m iterables of n elements each."""

    all_times = {solution: [] for solution in solutions}
    def show_title():
        print(f'{m} iterators of length {n:,}:')
    if verbose: show_title()
    def show_times(times, solution):
        print(*('%3d ms ' % t for t in times),
              name(solution))

    for _ in range(3):
        for solution in solutions:
            times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
            times = [round(t * 1e3, 3) for t in times]
            all_times[solution].append(times)
            if verbose: show_times(times, solution)
        if verbose: print()

    if verbose:
        print('best by min:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=min), solution)
        print('best by max:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=max), solution)
        print()

    stats.append((m,
                  [min(all_times[solution], key=min)
                   for solution in solutions]))
#-----------------------------------------------------------------------------
# Run the speedtest for several numbers of iterables
#-----------------------------------------------------------------------------
stats = []
verbose = False
total_elements = 2 * 10**6
for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
    test(m, total_elements // m)
#-----------------------------------------------------------------------------
# Print the speedtest results for use in the plotting script
#-----------------------------------------------------------------------------
print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
names = [name(solution) for solution in solutions]
print(f'{names = }')
print(f'{stats = }')
Code for plotting/table (also at Replit):
import matplotlib.pyplot as plt
names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]
colors = {
    'more_itertools__pylang': 'm',
    'fillvalue__Martijn_Pieters': 'red',
    'chain_raising__cjerdonek': 'gold',
    'ziptail__Stefan_Pochmann': 'lime',
    'zip': 'black',
}
ns = [n for n, _ in stats]
print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
print('-' * 95)
x = range(len(ns))
for i, name in enumerate(names):
    ts = [min(tss[i]) for _, tss in stats]
    color = colors[name]
    if color:
        plt.plot(x, ts, '.-', color=color, label=name)
        print('%29s' % name, *('%5.1f' % t for t in ts))
plt.xticks(x, ns, size=9)
plt.ylim(0, 133)
plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
plt.legend(loc='upper center')
#plt.show()
plt.savefig('zip_equal_plot.png', dpi=200)