Set itertools product repeat value per element


If I run the code

import itertools

products = itertools.product([0,1],repeat=3)
print(list(products))

I get the following output:

[(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)]

However, I would only like to repeat the 0 once and the 1 twice. In other words, I want the following output:

[(0, 1, 1), (1, 0, 1), (1, 1, 0)]

How can I achieve this?


Of course, I could do:

import itertools

products = itertools.permutations([0,1,1],3)
print(list(set(products)))

but in my case there are so many elements that calling set before iterating will exhaust memory.
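
To put rough numbers on the problem, here is a minimal sketch comparing how many tuples permutations generates with how many distinct arrangements actually exist. The helper names raw_permutation_count and distinct_arrangement_count are made up for illustration, and the sketch assumes only two distinct values (n0 zeros and n1 ones).

```python
from math import comb, factorial


def raw_permutation_count(n0, n1):
    # itertools.permutations treats equal values at different positions
    # as distinct, so it emits (n0 + n1)! tuples in total.
    return factorial(n0 + n1)


def distinct_arrangement_count(n0, n1):
    # A distinct arrangement is fixed by choosing which positions hold the 1s.
    return comb(n0 + n1, n1)


print(raw_permutation_count(1, 2), distinct_arrangement_count(1, 2))      # 6 3
print(raw_permutation_count(10, 10), distinct_arrangement_count(10, 10))  # 2432902008176640000 184756
```

With ten of each value there are only 184756 distinct results, but the permutations approach has to walk through about 2.4e18 tuples to find them.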


Solution

  • If you only have 0s and 1s, this would work:

    from itertools import combinations
    
    def gen(n0, n1):
        n = n0 + n1
        for c in combinations(range(n), n1):
            out = [0]*n
            for i in c:
                out[i]=1
            yield out
    
    list(gen(1,2))
    

    The way out is built may not be optimal, but the idea is there. I'll leave it to you to improve it if speed is an issue.

    Generalizing one step further:

    def gen(n0, n1, n2):
        n12 = n1 + n2
        n = n0 + n12
        for c12 in combinations(range(n), n12):
            out = [0]*n
            for i in c12:
                out[i] = 1
            for c2 in combinations(c12, n2):
                out_ = out.copy()
                for i in c2:
                    out_[i] = 2
                yield out_
    

    Again, the construction of out_ is likely suboptimal. With the same idea you can keep nesting to handle more and more distinct elements. If you have so many possible elements that the nesting becomes cumbersome to write by hand, you can make the process recursive, which is a fun exercise too:

    def gen(ns, elems=None, C=None, out=None):
        
        if elems is None:
            elems = list(range(len(ns)))
        else:
            assert len(elems) == len(ns)
            
        if out is None:
            # The output length is the sum of the counts, not their product.
            N = sum(ns)
            out = [elems[0]]*N
            C = range(N)
        
        if len(ns) == 1:
            yield out
        
        else:
            n = ns[-1]
            e = elems[-1]
            
            for c in combinations(C,n):
                out_ = out.copy()
                for i in c:
                    out_[i] = e
                C_ = [i for i in C if i not in c]
                yield from gen(ns[:-1], elems[:-1], C_, out_)
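
As a sanity check on small inputs, here is a self-contained sketch of the same recursive idea that can be compared against the set(permutations(...)) approach from the question. The names multiset_arrangements and place are hypothetical and not part of the answer above. The sketch assumes a non-empty list of counts and takes the output length to be the sum of the counts.

```python
from itertools import combinations, permutations


def multiset_arrangements(counts, values=None):
    """Lazily yield each distinct ordering in which values[k] appears counts[k] times."""
    if values is None:
        values = list(range(len(counts)))
    total = sum(counts)  # assumption: output length is the sum of the counts

    def place(k, free, out):
        # values[0] is the default fill, so nothing is left to place at k == 0.
        if k == 0:
            yield tuple(out)
            return
        # Choose which of the still-free positions receive values[k].
        for chosen in combinations(free, counts[k]):
            nxt = out.copy()
            for i in chosen:
                nxt[i] = values[k]
            taken = set(chosen)
            remaining = [i for i in free if i not in taken]
            yield from place(k - 1, remaining, nxt)

    yield from place(len(counts) - 1, list(range(total)), [values[0]] * total)


print(sorted(multiset_arrangements([1, 2])))  # [(0, 1, 1), (1, 0, 1), (1, 1, 0)]

# Cross-check against the brute-force approach, feasible only for small inputs.
brute_force = set(permutations([0, 0, 1, 2]))
assert set(multiset_arrangements([2, 1, 1])) == brute_force
```

Because this is a generator, results can be consumed one at a time without ever building the full set in memory.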