
Partitioned permutations in Python


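I wish to have a list of permutations of n elements [0, 1, 2, ..., n-1], where n = n1 + n2 + ... + nm. The elements are split into m consecutive partitions of sizes n1, ..., nm, and each permutation may only reorder elements within their own partition.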

For example, for n1, n2 = 3, 2 I would want:

0,1,2 | 3,4
0,1,2 | 4,3
0,2,1 | 3,4
0,2,1 | 4,3
...
2,1,0 | 4,3
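To pin down the definition: a permutation qualifies if every block of positions holds exactly the elements of that same block. A brute-force sketch (only practical for small n, and the helper name `respects_blocks` is hypothetical) filters `itertools.permutations(range(n))` by that rule and reproduces the listing above in the same order:

```python
from itertools import permutations

def respects_blocks(perm, lengths):
    # Each block of positions must contain exactly that block's elements.
    start = 0
    for k in lengths:
        if sorted(perm[start:start + k]) != list(range(start, start + k)):
            return False
        start += k
    return True

wanted = [p for p in permutations(range(5)) if respects_blocks(p, [3, 2])]
print(wanted[0], wanted[1], wanted[-1])
# (0, 1, 2, 3, 4) (0, 1, 2, 4, 3) (2, 1, 0, 4, 3)
```

This enumerates all n! permutations and discards most of them, so it only serves to illustrate what the desired output is.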

If I use itertools:

product(permutations([0,1,2]),permutations([3,4]))

I get:

[((0, 1, 2), (3, 4)), ((0, 1, 2), (4, 3)), ((0, 2, 1), (3, 4)), ((0, 2, 1), (4, 3)), ((1, 0, 2), (3, 4)), ((1, 0, 2), (4, 3)), ((1, 2, 0), (3, 4)), ((1, 2, 0), (4, 3)), ((2, 0, 1), (3, 4)), ((2, 0, 1), (4, 3)), ((2, 1, 0), (3, 4)), ((2, 1, 0), (4, 3))]

But I would like:

[(0, 1, 2, 3, 4), (0, 1, 2, 4, 3), ...]
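For the fixed two-block case, one way to get flat tuples straight from the `product` call is to concatenate each pair of per-block tuples. A minimal sketch using `itertools.chain.from_iterable` for the flattening (the variable name `flat` is just illustrative):

```python
from itertools import chain, permutations, product

# Permute each block separately, then flatten every (block1, block2) pair
# into a single tuple such as (0, 1, 2, 4, 3).
flat = [
    tuple(chain.from_iterable(pair))
    for pair in product(permutations([0, 1, 2]), permutations([3, 4]))
]

print(flat[:2])  # [(0, 1, 2, 3, 4), (0, 1, 2, 4, 3)]
```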

Also, it would be great if the input could simply be the lengths of the partitions:

input = [3,2]
or
input = [4,3,2]
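Accepting the lengths as input mostly means turning them into consecutive index ranges. A small sketch, assuming blocks are filled left to right starting at 0 as in the examples (the helper name `blocks_from_lengths` is hypothetical):

```python
from itertools import accumulate

def blocks_from_lengths(lengths):
    """Split 0..sum(lengths)-1 into consecutive ranges of the given sizes."""
    # Running totals give each block's end; prepend 0 as the first start.
    bounds = [0] + list(accumulate(lengths))
    return [range(start, stop) for start, stop in zip(bounds, bounds[1:])]

print([list(r) for r in blocks_from_lengths([4, 3, 2])])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8]]
```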

In the latter case I would get:

[(0,1,2,3,  4,5,6,  7,8),
 (0,1,2,3,  4,5,6,  8,7),
 (0,1,2,3,  4,6,5,  7,8),
 ...]
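Because each block is permuted independently, the number of results is the product of the factorials of the block lengths, n1! · n2! · ... · nm!. For [3,2] that is 3! · 2! = 12, matching the `product` output above, and for [4,3,2] it is 24 · 6 · 2 = 288. A quick check (the helper name `count_partitioned_perms` is illustrative; `math.prod` needs Python 3.8+):

```python
from math import factorial, prod

def count_partitioned_perms(lengths):
    # A block of size k contributes k! orderings, independently of the others.
    return prod(factorial(k) for k in lengths)

print(count_partitioned_perms([3, 2]))     # 12
print(count_partitioned_perms([4, 3, 2]))  # 288
```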

Any ideas?


Solution

  • As I understand the problem, the following code should work for your needs.

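    from itertools import permutations, product
    
    
    def part_perm_iter(ns):
        # Block boundaries, e.g. ns = [3, 2] -> inds = [0, 3, 5]
        inds = [sum(ns[:i]) for i in range(len(ns) + 1)]
        # (start, stop) pair for each block, e.g. [(0, 3), (3, 5)]
        pair_inds = zip(inds, inds[1:])
    
        # Permute each block independently, take the Cartesian product,
        # and concatenate the per-block tuples into one flat tuple.
        for p in product(*[permutations(range(a, b)) for a, b in pair_inds]):
            yield sum(p, ())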
    

    For example, print(list(part_perm_iter([2, 2]))) will print:

    [(0, 1, 2, 3), (0, 1, 3, 2), (1, 0, 2, 3), (1, 0, 3, 2)]
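With many blocks, `sum(p, ())` builds a new intermediate tuple for each block it adds. A sketch of the same generator that flattens with `itertools.chain.from_iterable` and computes the boundaries with `itertools.accumulate` instead (the name `part_perm_iter_chain` is hypothetical, chosen so it does not clash with the answer's function). Taking the first three results for [4,3,2] reproduces the rows listed in the question:

```python
from itertools import accumulate, chain, islice, permutations, product

def part_perm_iter_chain(ns):
    """Yield permutations of 0..sum(ns)-1 that only reorder within each block."""
    bounds = [0] + list(accumulate(ns))
    # One permutations() iterator per consecutive block of indices.
    per_block = [permutations(range(a, b)) for a, b in zip(bounds, bounds[1:])]
    # product() picks one ordering per block; chain flattens them into one tuple.
    for combo in product(*per_block):
        yield tuple(chain.from_iterable(combo))

for row in islice(part_perm_iter_chain([4, 3, 2]), 3):
    print(row)
# (0, 1, 2, 3, 4, 5, 6, 7, 8)
# (0, 1, 2, 3, 4, 5, 6, 8, 7)
# (0, 1, 2, 3, 4, 6, 5, 7, 8)
```

The output order is the same as the answer's version: `product` varies the last block fastest.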