Sorting an iterator in Python


I want to iterate over a big itertools.product, but in a different order from the one that product offers. The problem is that sorting the iterator with sorted takes a long time. For example:

from itertools import product
import time

RNG = 15
RPT = 6

start = time.time()
a = sorted(product(range(RNG), repeat=RPT), key=sum)
print("Sorted: " + str(time.time() - start))
print(type(a))

start = time.time()
a = product(range(RNG), repeat=RPT)
print("Unsorted: " + str(time.time() - start))
print(type(a))

Creating the sorted result takes about twice as long. I'm guessing this is because sorted actually goes through the whole iterator and returns a list, whereas the second, unsorted iterator is doing some sort of lazy-evaluation magic.
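
That guess can be checked directly: sorted() has to pull every item out of its input before it can return anything, while product() hands out one tuple per request. The sketch below makes this visible with a small hypothetical wrapper class, CountingIter (not part of itertools), that counts how many items have been pulled from the iterator it wraps:

```python
from itertools import product


class CountingIter:
    """Hypothetical helper: wraps an iterable and counts the items pulled from it."""

    def __init__(self, iterable):
        self._it = iter(iterable)
        self.pulled = 0

    def __iter__(self):
        return self

    def __next__(self):
        item = next(self._it)  # raises StopIteration when the source is exhausted
        self.pulled += 1
        return item


# Lazy: creating the product pulls nothing; each next() pulls exactly one tuple.
lazy = CountingIter(product(range(3), repeat=2))
print(lazy.pulled)           # 0
first = next(lazy)
print(first, lazy.pulled)    # (0, 0) 1

# Eager: sorted() must pull all 3**2 = 9 tuples before it can return its list.
source = CountingIter(product(range(3), repeat=2))
result = sorted(source, key=sum)
print(type(result).__name__, source.pulled)  # list 9
```

So the "Unsorted" timing mostly measures creating an iterator, which does very little work up front; the cost of product only shows up once you actually loop over it.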

I guess there are really two questions here.

  1. General question: is there a lazy evaluation way to change the order items appear in an iterator?
  2. Specific question: is there a way to loop through all m-length lists of ints less than n, hitting lists with smaller sums first?

Solution

  • If your objective is to reduce memory consumption, you could write your own generator that returns the permutations in order of their sum (see below). But if memory is not a concern, sorting the output of itertools.product() will be faster than the equivalent pure-Python generator.

    A recursive function can produce the combinations of values in order of their sum by merging multiple iterators (one per starting value), always advancing the one with the smallest current sum:

    def sumCombo(A,N):
        # A must be a sliceable sequence (range, list, tuple) sorted in ascending order
        if N==1:
            yield from ((n,) for n in A) # single item combos
            return
        pA = []                          # list of iterator/states
        for i,n in enumerate(A):         # for each starting value
            ip = sumCombo(A[i:],N-1)     # iterator recursion to N-1
            p  = next(ip)                # current N-1 combination
            pA.append((n+sum(p),p,n,ip)) # sum, state & iterator
        while pA:
            # index and state of the iterator with the smallest current sum
            i,(s,p,n,ip) = min(enumerate(pA),key=lambda e:e[1][0])
            ps = s
            while s == ps:             # output all combinations with this sum
                yield (n,*p)           # starting number followed by the recursed combo
                p = next(ip,None)      # advance iterator
                if p is None:
                    del pA[i]          # remove exhausted iterators
                    break
                s = n+sum(p)           # compute new sum
                pA[i] = (s,p,n,ip)     # and update states
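
The hand-written merge loop can also be delegated to the standard library: heapq.merge lazily merges iterables that are each already sorted, and accepts a key function (Python 3.5+). The sketch below is that variation, not the code above; sum_combos and _prefixed are illustrative names, it assumes values is sorted ascending, and the relative order of combinations with equal sums is not guaranteed to match sumCombo exactly.

```python
import heapq
from itertools import islice


def _prefixed(first, stream):
    """Yield (first, *rest) for every tuple `rest` in `stream`."""
    for rest in stream:
        yield (first, *rest)


def sum_combos(values, n):
    """Yield the non-decreasing n-tuples of `values` in non-decreasing order of sum.

    Assumes n >= 1 and `values` sorted in ascending order (e.g. a range).
    """
    values = tuple(values)
    if n == 1:
        for v in values:       # ascending values => ascending sums
            yield (v,)
        return
    # One already-sorted stream per starting value: prefixing a constant
    # keeps each recursive stream sorted by sum.
    streams = [_prefixed(v, sum_combos(values[i:], n - 1))
               for i, v in enumerate(values)]
    # heapq.merge pulls from the streams lazily, always taking the smallest sum.
    yield from heapq.merge(*streams, key=sum)


print(list(islice(sum_combos(range(15), 6), 4)))
```

The helper function _prefixed is used instead of an inline generator expression because a generator expression inside the list comprehension would capture the loop variable late and prefix every stream with the last value.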
    

    This produces only combinations of values (tuples in non-decreasing order), whereas product produces every distinct permutation of those combinations: 38,760 combinations vs. 11,390,625 products.
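
Both counts can be checked with a little combinatorics: the number of combinations with repetition of RPT values drawn from RNG is C(RNG + RPT - 1, RPT), and a combination with value multiplicities m1, m2, ... has n! / (m1! * m2! * ...) distinct orderings. A sketch (Python 3.8+ for math.comb and math.prod; distinct_perm_count is an illustrative name):

```python
from collections import Counter
from itertools import combinations_with_replacement
from math import comb, factorial, prod

RNG, RPT = 15, 6

# Combinations with repetition: C(RNG + RPT - 1, RPT)
n_combos = comb(RNG + RPT - 1, RPT)
print(n_combos)  # 38760


def distinct_perm_count(combo):
    """Distinct orderings of a multiset: len! / (product of multiplicity!)."""
    multiplicities = Counter(combo).values()
    return factorial(len(combo)) // prod(factorial(m) for m in multiplicities)


# Summing over every combination recovers the full size of the product.
total = sum(distinct_perm_count(c)
            for c in combinations_with_replacement(range(RNG), RPT))
print(total, RNG ** RPT)  # 11390625 11390625
```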

    In order to obtain all the products, you would need to run these combinations through a function that generates distinct permutations:

    def permuteDistinct(A):
        if len(A) == 1:
            yield tuple(A) # single value
            return
        seen = set()               # starting values already used
        for i,n in enumerate(A):   # for each starting value
            if n in seen: continue # skip duplicates so permutations stay distinct
            seen.add(n)
            for p in permuteDistinct(A[:i]+A[i+1:]):
                yield (n,*p)       # starting value & rest

    def sumProd(A,N):
        for p in sumCombo(A,N):           # combinations in order of sum
            yield from permuteDistinct(p) # each expanded into its distinct permutations
    

    So sumProd(range(RNG),RPT) produces all 11,390,625 permutations in order of their sum without storing them in a list, but it takes about 5 times longer to do so than sorting the product:

    a = sorted(product(range(RNG), repeat=RPT), key=sum) # 4.6 sec
    b = list(sumProd(range(RNG),RPT))                    # 23  sec
    
    list(map(sum,a)) == list(map(sum,b)) # True  (same order of sums)
    a == b                               # False (order differs for equal sums)
    
    a[5:15]            b[5:15]             sum
    (0, 1, 0, 0, 0, 0) (0, 1, 0, 0, 0, 0)  1
    (1, 0, 0, 0, 0, 0) (1, 0, 0, 0, 0, 0)  1
    (0, 0, 0, 0, 0, 2) (0, 0, 0, 0, 0, 2)  2
    (0, 0, 0, 0, 1, 1) (0, 0, 0, 0, 2, 0)  2
    (0, 0, 0, 0, 2, 0) (0, 0, 0, 2, 0, 0)  2
    (0, 0, 0, 1, 0, 1) (0, 0, 2, 0, 0, 0)  2
    (0, 0, 0, 1, 1, 0) (0, 2, 0, 0, 0, 0)  2
    (0, 0, 0, 2, 0, 0) (2, 0, 0, 0, 0, 0)  2
    (0, 0, 1, 0, 0, 1) (0, 0, 0, 0, 1, 1)  2
    (0, 0, 1, 0, 1, 0) (0, 0, 0, 1, 0, 1)  2
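
Since the combinations alone are cheap to produce, most of that extra time goes into the permutation expansion. There are only 38,760 combinations, so another option is to sort those in memory and expand them lazily. The sketch below is a variation on sumProd under that assumption; it swaps the recursive permuteDistinct for the classic iterative next-lexicographic-permutation algorithm, which never emits duplicate orderings. The names are illustrative and it has not been timed against the figures above.

```python
from itertools import combinations_with_replacement


def distinct_permutations(seq):
    """Yield the distinct permutations of `seq` in lexicographic order.

    Classic next-permutation algorithm: equal values are never swapped
    with each other, so each distinct ordering appears exactly once.
    """
    a = sorted(seq)
    n = len(a)
    while True:
        yield tuple(a)
        # Rightmost position i whose value can still be increased.
        i = n - 2
        while i >= 0 and a[i] >= a[i + 1]:
            i -= 1
        if i < 0:
            return                      # descending order: last permutation
        # Rightmost value to the right of i that is larger than a[i].
        j = n - 1
        while a[j] <= a[i]:
            j -= 1
        a[i], a[j] = a[j], a[i]
        a[i + 1:] = reversed(a[i + 1:])  # smallest arrangement of the tail


def sum_products(values, n):
    """Yield every n-tuple over `values` in non-decreasing order of sum.

    Only the combinations are held in memory; permutations are generated lazily.
    """
    combos = sorted(combinations_with_replacement(values, n), key=sum)
    for combo in combos:
        yield from distinct_permutations(combo)
```

Memory use is bounded by the list of combinations rather than the full product, which is the trade-off the original question was after.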
    

    If your process is searching for specific sums, it may be worth filtering on combinations first and only expanding distinct permutations for the combinations (sums) that meet your criteria. This could cut down the number of iterations considerably: sumCombo(range(RNG),RPT) alone takes about 0.22 sec, much faster than sorting the products.
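
A sketch of that filtering idea, assuming the search is for tuples with one exact target sum (products_with_sum and target are illustrative names). It rejects non-matching combinations cheaply and expands only the survivors; expanding via set(itertools.permutations(...)) is simple but generates all n! orderings before deduplicating, which is acceptable for short tuples like these.

```python
from itertools import combinations_with_replacement, permutations, product


def products_with_sum(values, n, target):
    """Yield every n-tuple over `values` whose elements add up to `target`."""
    for combo in combinations_with_replacement(values, n):
        if sum(combo) != target:
            continue                      # cheap rejection at the combination level
        # Expand only the matching combinations into their distinct orderings.
        yield from sorted(set(permutations(combo)))


# Cross-check against brute force on a small case.
fast = sorted(products_with_sum(range(4), 3, 4))
brute = sorted(t for t in product(range(4), repeat=3) if sum(t) == 4)
print(fast == brute, len(fast))  # True 12
```

If the combinations come from a sum-ordered source such as sumCombo, the loop could instead stop as soon as the sum exceeds the target.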