
How to make itertools combinations 'increase' evenly?


Consider the following example:

import itertools
import numpy as np

a = np.arange(0,5)
b = np.arange(0,3)
c = np.arange(0,7)

prods = itertools.product(a,b,c)

for p in prods:
    print(p)

This iterates over the products in the following order:

(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(0, 0, 3)
(0, 0, 4)
(0, 1, 0)

But I would much rather have the products given in order of their sum, e.g.

(0, 0, 0)
(0, 0, 1)
(0, 1, 0)
(1, 0, 0)
(0, 1, 1)
(1, 0, 1)
(1, 1, 0)
(0, 0, 2)

How can I achieve this without storing all combinations in memory?

Note: a, b and c are always ranges, but not necessarily with the same maximum. There is also no second-level ordering when the sums of two products are equal, i.e. (0,1,1) is equivalent to (2,0,0).
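For comparison, the requested ordering can be produced directly by sorting the whole product by its sum. This is only a reference sketch, useful for checking a lazier solution against. It uses built-in range objects in place of np.arange so it runs without NumPy, and it stores every product, so it does not meet the memory requirement.

```python
import itertools

# Same values as np.arange(0, 5), np.arange(0, 3) and np.arange(0, 7).
a = range(0, 5)
b = range(0, 3)
c = range(0, 7)

# Reference ordering only: sorted() builds a full list of all 5 * 3 * 7 = 105
# products before sorting, which is the memory cost the question wants to avoid.
# sorted() is stable, so products with equal sums keep itertools.product's order.
by_sum = sorted(itertools.product(a, b, c), key=sum)

for p in by_sum[:8]:
    print(p)
```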


Solution

  • The easiest way to do this without storing extra products in memory is with recursion. Instead of range(low, high) objects, pass in a list of (low, high) pairs and do the iteration yourself:

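    from typing import Iterator, List, Tuple
    
    
    def prod_by_sum(range_bounds: List[Tuple[int, int]]) -> Iterator[Tuple[int, ...]]:
        """
        Yield the Cartesian product of the input ranges in order of increasing sum.
    
        Each (low, high) pair stands for range(low, high); lower bounds are
        assumed to be non-negative.
    
        >>> range_bounds = [(2, 4), (3, 6), (0, 2)]
        >>> for prod in prod_by_sum(range_bounds):
        ...    print(prod)
        (2, 3, 0)
        (2, 3, 1)
        (2, 4, 0)
        (3, 3, 0)
        (2, 4, 1)
        (2, 5, 0)
        (3, 3, 1)
        (3, 4, 0)
        (2, 5, 1)
        (3, 4, 1)
        (3, 5, 0)
        (3, 5, 1)
    
        """
        def prod_by_sum_helper(start: int, goal_sum: int) -> Iterator[Tuple[int, ...]]:
            low, high = range_bounds[start]
            if start == len(range_bounds) - 1:
                # Last range: its value is forced to be whatever is left of goal_sum.
                if low <= goal_sum < high:
                    yield (goal_sum,)
                return
    
            # Try each value for this position that doesn't exceed goal_sum, then
            # recurse to fill the remaining positions with what is left over.
            for current in range(low, min(high, goal_sum + 1)):
                yield from ((current,) + extra
                            for extra in prod_by_sum_helper(start + 1, goal_sum - current))
    
        # Smallest and largest achievable sums (each range contributes low..high-1).
        lowest_sum = sum(lo for lo, hi in range_bounds)
        highest_sum = sum(hi - 1 for lo, hi in range_bounds)
    
        # Emit one sum "layer" at a time; only the recursion stack is kept in memory.
        for goal_sum in range(lowest_sum, highest_sum + 1):
            yield from prod_by_sum_helper(0, goal_sum)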
    

    For range_bounds = [(0, 5), (0, 3), (0, 7)], matching the question's a, b and c, the output starts with:

    (0, 0, 0)
    (0, 0, 1)
    (0, 1, 0)
    (1, 0, 0)
    (0, 0, 2)
    (0, 1, 1)
    (0, 2, 0)
    (1, 0, 1)
    (1, 1, 0)
    (2, 0, 0)
    

    You can implement this same process iteratively by modifying a single list and yielding copies of it, but the code becomes either more complicated or less efficient.

    You can also trivially modify this to support steps other than 1. However, it works less efficiently as the steps get larger, because the last range might not contain the value needed to produce the current sum, so more of the recursion is spent on branches that yield nothing. That seems unavoidable: looping efficiently over those products in order of sum would require solving a difficult computational problem.
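A minimal sketch of the step-supporting variant, assuming each input is a non-empty Python range object with a positive step (instead of the (low, high) pairs used above). The function name prod_by_sum_stepped and the rest_min/rest_max suffix sums are hypothetical additions for this sketch, not part of the answer's code. The last position uses an O(1) `in` check on the range. The suffix sums let the loop stop as soon as the remaining sum drops below what the later ranges can reach, which also removes the need for non-negative lower bounds.

```python
from itertools import islice
from typing import Iterator, Sequence, Tuple


def prod_by_sum_stepped(ranges: Sequence[range]) -> Iterator[Tuple[int, ...]]:
    """Yield the Cartesian product of `ranges` in order of increasing sum.

    Sketch only: assumes every range is non-empty and has a positive step.
    """
    n = len(ranges)
    if n == 0:
        yield ()  # same as itertools.product() with no arguments
        return

    # rest_min[i] / rest_max[i]: smallest / largest sum reachable by ranges[i:].
    rest_min = [0] * (n + 1)
    rest_max = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        rest_min[i] = rest_min[i + 1] + ranges[i][0]
        rest_max[i] = rest_max[i + 1] + ranges[i][-1]

    def helper(i: int, goal: int) -> Iterator[Tuple[int, ...]]:
        r = ranges[i]
        if i == n - 1:
            # Last position: the value is forced; range membership is O(1).
            if goal in r:
                yield (goal,)
            return
        for current in r:
            remaining = goal - current
            if remaining < rest_min[i + 1]:
                break  # values only increase, so every later value overshoots too
            if remaining <= rest_max[i + 1]:
                yield from ((current,) + extra for extra in helper(i + 1, remaining))

    # Steps can leave gaps in the achievable sums; such goals simply yield nothing.
    for goal in range(rest_min[0], rest_max[0] + 1):
        yield from helper(0, goal)


if __name__ == "__main__":
    # The question's ranges, first ten products in order of sum.
    for p in islice(prod_by_sum_stepped([range(0, 5), range(0, 3), range(0, 7)]), 10):
        print(p)
```

The pruning only skips branches that could not produce any product, so for step-1 ranges starting at 0 the output order matches prod_by_sum above.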