Search code examples
Tags: python, permutation, python-itertools, circular-permutations

Permutations without cycles


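I want to generate all possible permutations of a list, except that the cyclic permutations of the original list (rotations, going from left to right) should occur only once, i.e. only the original list itself is kept out of all its rotations.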

Here is an example:

Let the list be [A, B, C]. Then I want to have permutations such as [A, C, B] but not [B, C, A] as this would be a circular permutation of the original list [A, B, C]. For the list above, the result should look like

[A, B, C]
[A, C, B]
[B, A, C]
[C, B, A]

Here is a minimal working example that uses permutations() from itertools.

from itertools import permutations


def permutations_without_cycles(seq: list):
    """Print and return the permutations of *seq* minus its non-trivial rotations.

    All permutations are generated with ``itertools.permutations``. Every
    rotation of *seq* (``seq[i:] + seq[:i]``) except the unrotated sequence
    itself is then removed. The three intermediate listings are printed
    exactly as before; the cleaned list is additionally returned so callers
    can use the result instead of only reading it from stdout.

    Args:
        seq: Sequence to permute. It must support ``len``, slicing and ``+``
            concatenation (a list, as in the example).

    Returns:
        The remaining permutations as a list of tuples, in the order
        ``itertools.permutations`` produced them.
    """
    def show(title, rows):
        # Print a title followed by an index-numbered listing of the rows.
        print(title)
        for i, row in enumerate(rows):
            print(i, "\t", row)

    # Get a list of all permutations
    permutations_all = list(permutations(seq))
    show("\nAll permutations:", permutations_all)

    # Get a list of all cyclic permutations (rotation 0 is seq itself)
    cyclic_permutations = [tuple(seq[i:] + seq[:i]) for i in range(len(seq))]
    show("\nAll cyclic permutations:", cyclic_permutations)

    # Remove all cyclic permutations except for one: the unrotated sequence
    unwanted = cyclic_permutations[1:]
    permutations_cleaned = [p for p in permutations_all if p not in unwanted]
    show("\nCleaned permutations:", permutations_cleaned)

    return permutations_cleaned


def main():
    """Run the demonstration on the three-element example list."""
    permutations_without_cycles(seq=["A", "B", "C"])


# Run the demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

I would like to know whether there is a method in itertools that solves this problem more efficiently.


Solution

  • That's unusual, so no, that's not already in itertools. But we can optimize your approach significantly (mainly by filtering out the unwanted cyclic permutations with a set instead of a list, or even by comparing against just the single next unwanted one). Even more efficiently, we can compute the indexes of the unwanted permutations[*] and islice between them. See the full code at the bottom.

    [*] Using a simplified version of permutation_index from more-itertools.

    Benchmark results, using list(range(n)) as the sequence. Ints compare fairly quickly, so if the sequence elements were some objects with more expensive comparisons, my efficient solution would have an even bigger advantage, since it's the only one that doesn't rely on comparing permutations/elements.

    8 elements:
      1.76 ±  0.07 ms  efficient
      3.60 ±  0.76 ms  optimized_iter
      4.65 ±  0.81 ms  optimized_takewhile
      4.97 ±  0.43 ms  optimized_set
      8.19 ±  0.31 ms  optimized_generator
     21.42 ±  1.19 ms  original
    
    9 elements:
     13.11 ±  2.39 ms  efficient
     34.37 ±  2.83 ms  optimized_iter
     40.87 ±  4.49 ms  optimized_takewhile
     46.74 ±  2.27 ms  optimized_set
     78.79 ±  3.43 ms  optimized_generator
    237.72 ±  5.76 ms  original
    
    10 elements:
    160.61 ±  4.58 ms  efficient
    370.79 ± 14.71 ms  optimized_iter
    492.95 ±  2.45 ms  optimized_takewhile
    565.04 ±  9.68 ms  optimized_set
             too slow  optimized_generator
             too slow  original
    

    Code (Attempt This Online!):

    from itertools import permutations, chain, islice, filterfalse, takewhile
    from timeit import timeit
    from statistics import mean, stdev
    from collections import deque
    
    # Your original, just without the prints/comments, and returning the result
    def original(seq: list):
        """Return all permutations of *seq* except its non-trivial rotations.

        This is the reference implementation the other variants are checked
        against: it materializes every permutation and filters them against
        a list of the rotations of *seq* by 1 .. len(seq)-1 positions.
        """
        every_perm = list(permutations(seq))
        # Rotation 0 is seq itself and is deliberately not part of the skip list.
        rotations = [tuple(seq[shift:] + seq[:shift]) for shift in range(len(seq))]
        to_skip = rotations[1:]
        kept = []
        for perm in every_perm:
            if perm in to_skip:
                continue
            kept.append(perm)
        return kept
    
    
    # Your original with several optimizations
    def optimized_set(seq: list):
        """Lazily yield the permutations of *seq* that are not non-trivial rotations.

        The rotations by 1 .. len(seq)-1 positions are collected in a set, and
        ``filterfalse`` drops every permutation found in it. The elements of
        *seq* must be hashable because the rotation tuples are hashed.
        """
        unwanted = set()
        for shift in range(1, len(seq)):
            unwanted.add(tuple(seq[shift:] + seq[:shift]))
        is_unwanted = unwanted.__contains__
        return filterfalse(is_unwanted, permutations(seq))
    
    
    # Further optimized to filter by just the single next unwanted permutation
    def optimized_iter(seq: list):
        """Lazily yield the permutations of *seq*, skipping its non-trivial rotations.

        The permutation stream is cut into segments. Each segment is a
        two-argument ``iter`` that pulls from the shared stream until it meets
        the next rotation, which acts as the sentinel and is swallowed.
        Only one rotation is compared against at any time.
        """
        def segments():
            stream = permutations(seq)
            pull = stream.__next__
            # The first permutation is seq itself and is always kept.
            yield (pull(),)
            for shift in range(1, len(seq)):
                sentinel = tuple(seq[shift:] + seq[:shift])
                # Stops (and discards the sentinel) when the rotation shows up.
                yield iter(pull, sentinel)
            # Everything after the last rotation is passed through unchanged.
            yield stream
        return chain.from_iterable(segments())
    
    
    # Another way to filter by just the single next unwanted permutation
    def optimized_takewhile(seq: list):
        """Lazily yield the permutations of *seq*, skipping its non-trivial rotations.

        Works like a chain of segments over one shared permutation stream:
        each ``takewhile`` passes permutations through until one equals the
        next rotation, which ``takewhile`` consumes and drops.
        """
        def segments():
            stream = permutations(seq)
            # The first permutation is seq itself and is always kept.
            yield (next(stream),)
            size = len(seq)
            for shift in range(1, size):
                unwanted = tuple(seq[shift:] + seq[:shift])
                differs_from_unwanted = unwanted.__ne__
                yield takewhile(differs_from_unwanted, stream)
            # Remainder after the last rotation has been dropped.
            yield stream
        return chain.from_iterable(segments())
    
    
    # Yet another way to filter by just the single next unwanted permutation
    def optimized_generator(seq: list):
        """Generate the permutations of *seq*, skipping its non-trivial rotations.

        A plain generator version: permutations are forwarded one by one and
        compared only against the next rotation still to be skipped.
        """
        stream = permutations(seq)
        # The first permutation is seq itself and is always kept.
        yield next(stream)
        size = len(seq)
        for shift in range(1, size):
            unwanted = tuple(seq[shift:] + seq[:shift])
            for candidate in stream:
                if candidate == unwanted:
                    # Drop this rotation and move on to the next one.
                    break
                yield candidate
        # Nothing left to skip: forward the remainder of the stream.
        for candidate in stream:
            yield candidate
    
    
    # Compute the indexes of the unwanted permutations and islice between them
    def efficient(seq):
        """Lazily yield the permutations of *seq*, skipping its non-trivial rotations.

        Instead of comparing permutations, this computes the position of each
        rotation in the ``permutations(seq)`` output and uses ``islice`` to
        pass through exactly the permutations in between. No elements or
        tuples are ever compared.
        """
        def segments():
            stream = permutations(seq)
            # The first permutation is seq itself and is always kept.
            yield (next(stream),)
            consumed = 1  # how many permutations have been taken from stream
            size = len(seq)
            for shift in range(1, size):
                # Position of the rotation by `shift` in the permutation order,
                # accumulated digit by digit: while more than `shift` elements
                # remain unplaced the digit is `shift`, afterwards it is 0.
                target = 0
                for remaining in range(size, 1, -1):
                    digit = shift if remaining > shift else 0
                    target = target * remaining + digit
                # Forward everything up to, but not including, the rotation.
                yield islice(stream, target - consumed)
                # Discard the rotation itself.
                next(stream)
                consumed = target + 1
            # Everything after the last rotation is passed through unchanged.
            yield stream
        return chain.from_iterable(segments())
    
    
    # All implementations that are checked and benchmarked below. The order
    # matters: the 10-element benchmark at the bottom runs funcs[2:], i.e. it
    # leaves out the first two entries of this tuple.
    funcs = original, optimized_generator, optimized_set, optimized_iter, optimized_takewhile, efficient
    
    
    #--- Correctness checks
    
    # Show each implementation's output for the example from the question.
    seq = ["A", "B", "C"]
    for f in funcs:
        print(*f(seq), f.__name__)

    # Every implementation must agree with the reference on a tuple input
    # whose elements are distinct but not sorted.
    seq = (3, 1, 4, 5, 9, 2, 6)
    assert all(list(f(seq)) == original(seq) for f in funcs)

    # ... and on lists of every length from 0 to 8.
    for n in range(9):
        seq = list(range(n))
        assert all(list(f(seq)) == original(seq) for f in funcs)
    
    
    #--- Speed tests
    
    def test(seq, funcs):
        """Benchmark each function in *funcs* on *seq* and print a ranking.

        Every function is timed 25 times, in interleaved rounds, with a single
        call per measurement; its result is fully consumed through a
        zero-length ``deque`` so lazy iterators actually do their work. For
        each function the five fastest runs are summarized as
        ``mean ± stdev`` in milliseconds, and the functions are printed from
        fastest to slowest.

        Args:
            seq: The input sequence handed to every function.
            funcs: The callables to benchmark; each takes *seq* and returns
                an iterable.
        """
        print()
        print(len(seq), 'elements:')

        times = {f: [] for f in funcs}

        def best_ms(f):
            # The five fastest measurements of f, converted to milliseconds.
            return [t * 1e3 for t in sorted(times[f])[:5]]

        def stats(f):
            ts = best_ms(f)
            return f'{mean(ts):6.2f} ± {stdev(ts):5.2f} ms '

        for _ in range(25):
            for f in funcs:
                t = timeit(lambda: deque(f(seq), 0), number=1)
                times[f].append(t)

        # Rank by the numeric mean. Sorting by the formatted string breaks as
        # soon as a mean reaches 1000 ms, because '1200.00' sorts before
        # '500.00' when compared as text.
        for f in sorted(funcs, key=lambda f: mean(best_ms(f))):
            print(stats(f), f.__name__)
    
    # Benchmark every implementation on 8 and 9 elements.
    test(list(range(8)), funcs)
    test(list(range(9)), funcs)
    # For 10 elements, skip the first two entries of funcs (original and
    # optimized_generator), which the results above list as "too slow".
    test(list(range(10)), funcs[2:])