Tags: python, list, python-3.x, python-itertools, cartesian-product

Cartesian product that returns outputs of varied lengths


So I have these lists:

a = [1, 2, 3]
b = [11, 12, 13, 14]
c = [21, 22, 23, 24, 25, 26]

I want to get all possible combinations (duplicates are fine) consisting of 2 consecutive elements from a, 3 consecutive elements from b and 3 consecutive elements from c. Like this:

([1, 2], [11, 12, 13], [21, 22, 23]) # 1
([1, 2], [11, 12, 13], [22, 23, 24]) # 2
# all the way to...
([2, 3], [12, 13, 14], [24, 25, 26]) # 16

If I use itertools.product(), it only gives me 1 from each list:

import itertools

def cartesian(the_list):
    for i in itertools.product(*the_list):
        yield i

a = [1, 2, 3]
b = [11, 12, 13, 14]
c = [21, 22, 23, 24, 25, 26]

test = cartesian([a, b, c])

print(next(test)) 
# Gives (1, 11, 21). But I need ([1, 2], [11, 12, 13], [21, 22, 23])

print(next(test)) 
# Gives (1, 11, 22). But I need ([1, 2], [11, 12, 13], [22, 23, 24])

I could use multiple nested for loops, but I would need too many loops if I have a lot of lists.

So how do I implement an algorithm that gives me all such combinations, where each combination takes a given number of elements from each input list?


Solution

  • Build a generator function that yields every slice of a given length (number consecutive elements) from a list, then pass one such generator per list to product, like this:

    >>> from itertools import product
    >>> def get_chunks(items, number=3):
    ...     for i in range(len(items) - number + 1): 
    ...         yield items[i: i + number]
    ...

    Then define your cartesian generator on top of it, like this:

    >>> def cartesian(a, b, c):
    ...     for items in product(get_chunks(a, 2), get_chunks(b), get_chunks(c)):
    ...         yield items
    ...

    If you are using Python 3.3+, you can delegate to product directly with yield from, like this:

    >>> def cartesian(a, b, c):
    ...     yield from product(get_chunks(a, 2), get_chunks(b), get_chunks(c))
    ...

    Collecting all the results into a list gives the 16 expected tuples:

    >>> from pprint import pprint
    >>> pprint(list(cartesian([1, 2, 3],[11, 12, 13, 14],[21, 22, 23, 24, 25, 26])))
    [([1, 2], [11, 12, 13], [21, 22, 23]),
     ([1, 2], [11, 12, 13], [22, 23, 24]),
     ([1, 2], [11, 12, 13], [23, 24, 25]),
     ([1, 2], [11, 12, 13], [24, 25, 26]),
     ([1, 2], [12, 13, 14], [21, 22, 23]),
     ([1, 2], [12, 13, 14], [22, 23, 24]),
     ([1, 2], [12, 13, 14], [23, 24, 25]),
     ([1, 2], [12, 13, 14], [24, 25, 26]),
     ([2, 3], [11, 12, 13], [21, 22, 23]),
     ([2, 3], [11, 12, 13], [22, 23, 24]),
     ([2, 3], [11, 12, 13], [23, 24, 25]),
     ([2, 3], [11, 12, 13], [24, 25, 26]),
     ([2, 3], [12, 13, 14], [21, 22, 23]),
     ([2, 3], [12, 13, 14], [22, 23, 24]),
     ([2, 3], [12, 13, 14], [23, 24, 25]),
     ([2, 3], [12, 13, 14], [24, 25, 26])]
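
The question also mentions having a lot of lists, while the cartesian function above is written for exactly three. One possible way to generalize it is sketched below. The name cartesian_chunks and the (items, number) pair convention are assumptions made for illustration, not part of the original answer; get_chunks is the same helper as above.

```python
from itertools import product


def get_chunks(items, number=3):
    """Yield each run of `number` consecutive elements of `items`."""
    for i in range(len(items) - number + 1):
        yield items[i:i + number]


def cartesian_chunks(*pairs):
    """Yield the product of chunks for any number of lists.

    Hypothetical helper: each argument is an (items, number) pair, where
    `number` is how many consecutive elements to take from `items`.
    """
    yield from product(*(get_chunks(items, number) for items, number in pairs))


a = [1, 2, 3]
b = [11, 12, 13, 14]
c = [21, 22, 23, 24, 25, 26]

results = list(cartesian_chunks((a, 2), (b, 3), (c, 3)))
print(len(results))   # 16
print(results[0])     # ([1, 2], [11, 12, 13], [21, 22, 23])
print(results[-1])    # ([2, 3], [12, 13, 14], [24, 25, 26])
```

Note that if any requested length is larger than its list, get_chunks yields nothing for that list, so the whole product is empty.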