Tags: python, python-3.x, python-itertools

Carry out itertools.product allowing for a different number of repetitions for different iterables


Let's say I have two iterables A = 'ab' and B = '12'.

itertools.product returns an iterator that generates the Cartesian product of A and B, e.g.

>>> list(itertools.product(A,B))
[('a', '1'), ('a', '2'), ('b', '1'), ('b', '2')]
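As a rough mental model (the itertools docs describe product as roughly equivalent to nested for-loops), product(A, B) behaves like a nested loop in which the rightmost iterable varies fastest. A small check, with A and B as above:

```python
from itertools import product

A = 'ab'
B = '12'

# B is the inner loop, so it advances fastest, just like the last argument to product
nested = [(a, b) for a in A for b in B]

print(nested == list(product(A, B)))  # True
```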

The function takes an optional keyword argument, repeat, which can be used to compute the Cartesian product of an iterable with itself, e.g.

>>> list(itertools.product(A,repeat=2))
[('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]

and is equivalent to list(itertools.product(A,A)).

Using repeat=2 with both A and B gives

>>> list(itertools.product(A,B,repeat=2))
[('a', '1', 'a', '1'), ('a', '1', 'a', '2'), ('a', '1', 'b', '1'), ('a', '1', 'b', '2'), ('a', '2', 'a', '1'), ('a', '2', 'a', '2'), ('a', '2', 'b', '1'), ('a', '2', 'b', '2'), ('b', '1', 'a', '1'), ('b', '1', 'a', '2'), ('b', '1', 'b', '1'), ('b', '1', 'b', '2'), ('b', '2', 'a', '1'), ('b', '2', 'a', '2'), ('b', '2', 'b', '1'), ('b', '2', 'b', '2')]

and is equivalent to list(itertools.product(A,B,A,B)).
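In other words, repeat=k repeats the whole argument list k times rather than each iterable individually, which is why a single integer cannot express "A twice, B three times". A quick sketch checking this by building the repeated argument list with plain list multiplication:

```python
from itertools import product

A = 'ab'
B = '12'

with_repeat = list(product(A, B, repeat=2))

# [A, B] * 2 == [A, B, A, B]: the entire argument list is repeated, not each item
spelled_out = list(product(*([A, B] * 2)))

print(with_repeat == spelled_out)  # True
print(len(with_repeat))            # 16, i.e. (2 * 2) ** 2
```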

But now let's say I want to find the Cartesian product of n_A repetitions of A and n_B repetitions of B, where n_A and n_B don't have to be the same. How can I do this? It would be nice if repeat accepted a tuple (n_A, n_B) so that I could write

list(itertools.product(A,B,repeat=(n_A,n_B)))

e.g.

list(itertools.product(A,B,repeat=(2,3))) == list(itertools.product(A,A,B,B,B))

but this doesn't appear to be allowed.

Note: technically (A,A,B,B,B) gives a different product from (A,B,A,B,B), but I'll be sorting the outputs anyway, so I don't care about the order of the inputs.
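To illustrate why the input order stops mattering: assuming "sorting the outputs" means sorting the elements inside each tuple and then sorting the list of tuples, both argument orders normalise to the same result, because every tuple from one product is a rearrangement of a tuple from the other. The helper name normalised below is hypothetical, chosen only for this sketch.

```python
from itertools import product

A = 'ab'
B = '12'

def normalised(iterables):
    # Hypothetical helper: sort elements within each tuple, then sort the list of tuples
    return sorted(tuple(sorted(t)) for t in product(*iterables))

grouped = [A, A, B, B, B]
interleaved = [A, B, A, B, B]

print(list(product(*grouped)) == list(product(*interleaved)))  # False: raw tuples differ
print(normalised(grouped) == normalised(interleaved))          # True: same after sorting
```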


Solution

  • Use tee to "duplicate" each iterable the required number of times, flatten the copies into a single sequence of arguments with chain, and pass them to product as individual arguments using * unpacking.

    from itertools import product, chain, tee
    
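    def myproduct(*iterables, repeat=1):
        # An int behaves exactly like product's own repeat argument
        if isinstance(repeat, int):
            return product(*iterables, repeat=repeat)
        # Otherwise expect one count per iterable; map() would silently
        # stop at the shorter input, so check that the lengths match
        assert isinstance(repeat, tuple) and len(repeat) == len(iterables)
        # tee(it, n) returns n independent copies of it; chain flattens them,
        # e.g. repeat=(2, 3) yields the arguments A, A, B, B, B
        args = chain(*map(tee, iterables, repeat))
        return product(*args)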
    
    
    A = 'ab'
    B = '12'
    
    n_A = 2
    n_B = 3
    
    
    result = list(product(A, A, B, B, B))
    result2 = list(myproduct(A, B, repeat=(n_A, n_B)))
    
    print(result == result2)
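A possible alternative without tee, sketched under the assumption that each input is finite: materialise every iterable into a tuple once, then repeat that tuple with list multiplication. Because each input is read only once, this also works for one-shot iterators such as generators (passing the same generator object to product several times would not, since the first copy exhausts it). The name product_with_counts is a hypothetical helper, not part of itertools.

```python
from itertools import chain, product

def product_with_counts(iterables, counts):
    # Hypothetical helper: product over iterables[i] repeated counts[i] times
    if len(iterables) != len(counts):
        raise ValueError("need exactly one count per iterable")
    # Read each input once, so generators and other one-shot iterators are fine
    pools = [tuple(it) for it in iterables]
    # [pool] * n repeats the same tuple n times; chain.from_iterable flattens the lists
    args = chain.from_iterable([pool] * n for pool, n in zip(pools, counts))
    return product(*args)

A = 'ab'
B = '12'

result = list(product_with_counts([A, B], (2, 3)))
print(result == list(product(A, A, B, B, B)))  # True
print(len(result))                             # 32, i.e. 2**2 * 2**3
```

Like myproduct above, this keeps the grouped order (A, A, B, B, B); the explicit ValueError replaces the assert so that the length check still runs when Python is started with -O.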