Let's say I have two iterables, A = 'ab' and B = '12'. itertools.product returns an iterator that generates the Cartesian product of A and B; e.g.
>>> list(itertools.product(A,B))
[('a', '1'), ('a', '2'), ('b', '1'), ('b', '2')]
The function has an optional keyword argument, repeat, which can be used to find the Cartesian product of an iterable with itself; e.g.
>>> list(itertools.product(A,repeat=2))
[('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]
and is equivalent to list(itertools.product(A,A)).
Using repeat=2 with both A and B then gives
>>> list(itertools.product(A,B,repeat=2))
[('a', '1', 'a', '1'), ('a', '1', 'a', '2'), ('a', '1', 'b', '1'), ('a', '1', 'b', '2'), ('a', '2', 'a', '1'), ('a', '2', 'a', '2'), ('a', '2', 'b', '1'), ('a', '2', 'b', '2'), ('b', '1', 'a', '1'), ('b', '1', 'a', '2'), ('b', '1', 'b', '1'), ('b', '1', 'b', '2'), ('b', '2', 'a', '1'), ('b', '2', 'a', '2'), ('b', '2', 'b', '1'), ('b', '2', 'b', '2')]
and is equivalent to list(itertools.product(A,B,A,B)).
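As a quick sanity check of the two equivalences stated above, here is a minimal sketch that compares the repeat form with the spelled-out form. It only uses A and B as defined earlier.

```python
from itertools import product

A = 'ab'
B = '12'

# repeat=2 with one iterable is the same as passing that iterable twice.
single_matches = list(product(A, repeat=2)) == list(product(A, A))

# repeat=2 with two iterables repeats the whole argument list: A, B, A, B.
pair_matches = list(product(A, B, repeat=2)) == list(product(A, B, A, B))

print(single_matches, pair_matches)
```

Note that repeat duplicates the entire argument sequence, which is why the second case gives (A, B, A, B) rather than (A, A, B, B).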
But now let's say I want to find the Cartesian product of n_A repetitions of A and n_B repetitions of B, where n_A and n_B don't have to be the same. How can I do this? It would be nice if repeat took the tuple (n_A, n_B) and I could write
list(itertools.product(A,B,repeat=(n_A,n_B)))
e.g.
list(itertools.product(A,B,repeat=(2,3))) == list(itertools.product(A,A,B,B,B))
but this doesn't appear to be allowed.
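To make "doesn't appear to be allowed" concrete, the sketch below tries the tuple form and records what happens. In CPython, as far as I can tell, repeat has to be an integer, so the call fails with a TypeError as soon as product is called; the variable name error is only for illustration.

```python
from itertools import product

A = 'ab'
B = '12'

error = None
try:
    # repeat is expected to be an integer, so a tuple is rejected.
    product(A, B, repeat=(2, 3))
except TypeError as exc:
    error = exc

print(type(error).__name__)
```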
Note: technically (A,A,B,B,B) is a different product from (A,B,A,B,B); however, I'll be sorting the outputs anyway, so I don't care about the order of the inputs.
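The note above can be illustrated with a small sketch. The two argument orders give different raw outputs, but each tuple in one corresponds to a tuple in the other with its positions rearranged. The last check assumes that "sorting the outputs" means sorting the elements inside each tuple; that reading is my assumption, not something stated in the question.

```python
from itertools import product

A = 'ab'
B = '12'

grouped = list(product(A, A, B, B, B))
interleaved = list(product(A, B, A, B, B))

# The raw outputs differ because the positions are arranged differently.
raw_equal = grouped == interleaved

# Swapping positions 1 and 2 of each interleaved tuple recovers the grouped tuples.
reordered = {(t[0], t[2], t[1], t[3], t[4]) for t in interleaved}
same_up_to_positions = reordered == set(grouped)

# Assumption: sorting within each tuple, then sorting the list of tuples.
normalised_grouped = sorted(tuple(sorted(t)) for t in grouped)
normalised_interleaved = sorted(tuple(sorted(t)) for t in interleaved)
same_after_sorting = normalised_grouped == normalised_interleaved

print(raw_equal, same_up_to_positions, same_after_sorting)
```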
One option is to use tee to "duplicate" each iterable, flatten the copies into a single sequence of arguments, and use * unpacking to pass them to product as individual arguments:
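from itertools import product, chain, tee

def myproduct(*iterables, repeat=1):
    # A single int behaves exactly like itertools.product's own repeat.
    if isinstance(repeat, int):
        return product(*iterables, repeat=repeat)
    # Otherwise expect one repetition count per iterable.
    assert isinstance(repeat, tuple)
    # tee(iterable, n) returns n independent iterators over the same items;
    # chain flattens those groups into one argument list for product.
    args = chain(*map(tee, iterables, repeat))
    return product(*args)

A = 'ab'
B = '12'
n_A = 2
n_B = 3
result = list(product(A, A, B, B, B))
result2 = list(myproduct(A, B, repeat=(n_A, n_B)))
print(result == result2)  # True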
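For comparison, here is a sketch of an alternative that avoids tee by repeating each iterable with itertools.repeat. The helper name product_with_counts and its (iterables, counts) signature are hypothetical, chosen only for this example. It materialises each input as a tuple first, on the assumption that the inputs are finite, so one-shot iterators are handled as well.

```python
from itertools import chain, product, repeat

def product_with_counts(iterables, counts):
    """Cartesian product with iterables[i] repeated counts[i] times (hypothetical helper)."""
    if len(iterables) != len(counts):
        raise ValueError("need exactly one count per iterable")
    # Materialise each input once so it can safely be reused several times.
    pools = [tuple(iterable) for iterable in iterables]
    # repeat(pool, n) yields the same pool n times; chain flattens the groups.
    args = chain.from_iterable(repeat(pool, n) for pool, n in zip(pools, counts))
    return product(*args)

A = 'ab'
B = '12'
result = list(product_with_counts([A, B], [2, 3]))
print(result == list(product(A, A, B, B, B)))
```

Taking the counts as a separate sequence keeps the helper from overloading the meaning of repeat, at the cost of a signature that differs from itertools.product.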