
Efficiently count lists with certain properties


My purpose is to count permutations with certain properties. I first generate the permutations and then remove those that do not satisfy the desired properties. How could I improve the code to be able to enumerate more permutations? I am currently able to handle permutations of 14 elements.

Edit: Following suggestions by Jérôme Richard and attempt0, I have refactored the code to use generators. Am I doing it right?

from itertools import permutations

def check(seq, verbose=False):
    """check that the elements of the sequence equal a difference of previous elements"""
    n = len(seq)
    for k in range(1, n-1):
        # build a set of admissible values
        dk = {abs(seq[i]-seq[j]) for i in range(0, k) for j in range(i+1, k+1) if i < j}
        if k > 0 and verbose:
            print('current index = ', k)
            print('current subsequence = ', seq[:k+1])
            print('current admissible values = ', dk)
            print('next element = ', seq[k+1])
        # check if the next element is in the set of admissible values
        if k > 0 and seq[k+1] not in dk:
            # return an invalid subsequence (k+2 to include the invalid element)
            return seq[:k+2]
    return seq

def is_valid(seq):
    """check that the sequence satisfies certain properties"""
    n = len(seq)
    if n < 3:
        return False
    if len(check(seq)) == n:
        return True
    return False

def filter_perms(perms):
    for perm in perms:
        if is_valid(perm):
            yield perm

def make_perms(n):
    """The elements of the list are integers, where a list of length n stores all integers from 1 to n."""
    for p in permutations(range(1,n-1)):
        yield (n,) + p + (n-1,)

def enumerate_perms(n):
    perms = make_perms(n)
    return filter_perms(perms)

# testing a good sequence
seq=(5, 2, 3, 1, 4)
check(seq, verbose=True)
is_valid(seq)
# True

# testing a bad sequence
seq=[5, 2, 1, 3, 4]
check(seq, verbose=True)
is_valid(seq)
# False

# testing permutations
tuple(make_perms(5))

# testing enumeration
tuple(enumerate_perms(5))
# ((5, 2, 3, 1, 4), (5, 3, 2, 1, 4))

len(tuple(enumerate_perms(14)))
# 29340

My initial questions (partly answered in the comments): Should I look into using numpy arrays? Should I look into using a generator? Should I save the permutations to a database?


Solution

  • Filtering is a good approach in terms of implementation simplicity. However, it has to iterate through every possible permutation, which becomes infeasible as N grows.

    For example, for N=20 you would iterate over 18! ≈ 6.4 quadrillion permutations (the first and last elements are fixed). Even at tens of millions of permutations per second, merely iterating over them would take years. This means that, instead of filtering, it is necessary to prevent the generation of unwanted elements (this is called pruning).
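
    As a rough sanity check of that figure, here is the arithmetic as a short sketch. The throughput assumed_rate is a made-up number for illustration only, not a measurement.

```python
import math

n = 20
# The first and last positions are fixed (n and n - 1), so only n - 2 values are permuted.
candidates = math.factorial(n - 2)
print(f"{candidates:,}")  # 6,402,373,705,728,000

# Hypothetical throughput, assumed purely for illustration (not measured).
assumed_rate = 10_000_000  # permutations per second
years = candidates / assumed_rate / (3600 * 24 * 365)
print(f"about {years:.0f} years at {assumed_rate:,} permutations per second")
```

    Whatever rate is plugged in, the factorial growth dominates, which is why the search space has to be cut down rather than scanned faster.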

    Since you have already commented about this idea, I am going to skip the explanation and just show you how to implement it. (I'll attach the full code, including the benchmark, at the bottom, so please use that if you want to try it out.)

    def _filtered_permutations(
        items: Sequence[int],
        not_used_indexes: set[int],
        current_permutation: list[int],
        # The condition function is configurable to make this function more versatile.
        is_valid: Callable[[Sequence[int]], bool],
    ) -> Iterable[list[int]]:
        if len(current_permutation) == len(items):
            # Since we picked a valid element below, there is no need to validate it here.
            yield current_permutation
        else:
            for i in not_used_indexes:
                # This is the in-progress validation for the partial permutation.
                # If it is found to violate the conditions, we skip the further search.
                # This greatly reduces the generation of unwanted elements.
                next_permutation = [*current_permutation, items[i]]
                if not is_valid(next_permutation):
                    continue
    
                yield from _filtered_permutations(
                    items=items,
                    not_used_indexes=not_used_indexes - {i},
                    current_permutation=next_permutation,
                    is_valid=is_valid,
                )
    
    
    def is_valid_original(seq: Sequence[int]) -> bool:
        # This is equivalent to your original is_valid function, except I've removed the unnecessary list creation.
        for i in range(len(seq) - 2):
            if seq[i + 2] not in {abs(seq[j] - seq[k]) for j in range(1, i + 2) for k in range(i + 1)}:
                return False
        return True
    
    
    def enumerate_perms_pre_filter(n: int) -> Iterable[tuple[int, ...]]:
        lst = list(range(1, n - 1))
        head = n
        tail = n - 1
    
        def _is_valid(current):
            """A wrapper function that inserts the first and last elements."""
            if len(current) == len(lst):
                return is_valid_original((head, *current, tail))
            return is_valid_original((head, *current))
    
        for p in _filtered_permutations(
            items=lst,
            not_used_indexes=set(range(len(lst))),
            current_permutation=[],
            is_valid=_is_valid,
        ):
            yield head, *p, tail
    

    This ran 20-30 times faster than the original code on my PC.

    At this point, the bottleneck is that the is_valid function recomputes all the pairwise differences on every call. I'm not sure if I can make this versatile enough, so here I will provide an implementation specific to your problem.

    This may look similar to the implementation above, but it achieves a significant speedup by decomposing the is_valid function and carrying the set of differences along instead.

    def _enumerate_perms_optimized(
        items: Sequence[int],
        current_permutation: list[int],
        not_used_indexes: set[int],
        current_diffs: set[int],
    ) -> Iterator[list[int]]:
        if len(current_permutation) == len(items) + 1:
            # Since we picked a valid element below, there is no need to validate it here.
            yield current_permutation
        else:
            for i in not_used_indexes:
                # By carrying the differences, we can validate it by simply checking whether the set contains the value.
                next_value = items[i]
                if len(current_permutation) > 1 and next_value not in current_diffs:
                    continue
    
                yield from _enumerate_perms_optimized(
                    items=items,
                    not_used_indexes=not_used_indexes - {i},
                    current_permutation=[*current_permutation, next_value],
                    # The differences between the elements that have already been added have already been calculated,
                    # so we only need to calculate the differences between the new element and them.
                    current_diffs={*current_diffs, *(abs(next_value - prev_value) for prev_value in current_permutation)},
                )
    
    
    def enumerate_perms_optimized(n: int) -> Iterator[tuple[int, ...]]:
        lst = list(range(1, n - 1))
        head = n
        tail = n - 1
    
        for p in _enumerate_perms_optimized(
            items=lst,
            not_used_indexes=set(range(len(lst))),
            current_permutation=[head],
            current_diffs=set(),
        ):
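            # The tail needs no check: n and 1 are both already placed, so |n - 1| = n - 1 is always among the differences.
            yield *p, tail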
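
    To see why carrying the differences is safe, the following small sketch checks that the incrementally updated set always equals the set recomputed from scratch. The helper names diffs_from_scratch and extend_diffs are hypothetical, introduced only for this illustration.

```python
from itertools import combinations


def diffs_from_scratch(prefix):
    """All absolute pairwise differences of the prefix (what the filter recomputes each time)."""
    return {abs(a - b) for a, b in combinations(prefix, 2)}


def extend_diffs(current_diffs, prefix, next_value):
    """Carried update: only compare the new value against the values already in the prefix."""
    return current_diffs | {abs(next_value - prev) for prev in prefix}


seq = (5, 2, 3, 1, 4)  # the valid example from the question
carried = set()
for k, value in enumerate(seq):
    carried = extend_diffs(carried, seq[:k], value)
    # The carried set must match a full recomputation over the longer prefix.
    assert carried == diffs_from_scratch(seq[: k + 1])

print(sorted(carried))  # [1, 2, 3, 4]
```

    The update touches only as many pairs as there are elements in the prefix, whereas the recomputation touches every pair in it.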
    

    Here is the test and the benchmark.

    import time
    from collections.abc import Iterator, Sequence
    from itertools import permutations
    from typing import Callable, Iterable
    
    
    # ---------- Original implementation ----------
    def is_valid_original(seq: Sequence[int]) -> bool:
        # This is equivalent to your original is_valid function, except I've removed the unnecessary list creation.
        for i in range(len(seq) - 2):
            if seq[i + 2] not in {abs(seq[j] - seq[k]) for j in range(1, i + 2) for k in range(i + 1)}:
                return False
        return True
    
    
    def enumerate_perms_1(n: int) -> list[tuple[int, ...]]:
        lst = list(range(1, n - 1))
        perms = [p for perm in permutations(lst) for p in [(n, *perm, n - 1)] if is_valid_original(p)]
        return perms
    
    
    # ---------- Updated implementation ----------
    def check(seq, verbose=False):
        """check that the elements of the sequence equal a difference of previous elements"""
        n = len(seq)
        for k in range(1, n - 1):
            # build a set of admissible values
            dk = {abs(seq[i] - seq[j]) for i in range(0, k) for j in range(i + 1, k + 1) if i < j}
            if k > 0 and verbose:
                print("current index = ", k)
                print("current subsequence = ", seq[: k + 1])
                print("current admissible values = ", dk)
                print("next element = ", seq[k + 1])
            # check if the next element is in the set of admissible values
            if k > 0 and seq[k + 1] not in dk:
                # return an invalid subsequence (k+2 to include the invalid element)
                return seq[: k + 2]
        return seq
    
    
    def is_valid(seq):
        """check that the sequence satisfies certain properties"""
        n = len(seq)
        if n < 3:
            return False
        if len(check(seq)) == n:
            return True
        return False
    
    
    def filter_perms(perms):
        for perm in perms:
            if is_valid(perm):
                yield perm
    
    
    def make_perms(n):
        """The elements of the list are integers, where a list of length n stores all integers from 1 to n."""
        for p in permutations(range(1, n - 1)):
            yield (n,) + p + (n - 1,)
    
    
    def enumerate_perms_2(n):
        perms = make_perms(n)
        return filter_perms(perms)
    
    
    # ---------- Pre-filtered implementation ----------
    def _filtered_permutations(
        items: Sequence[int],
        not_used_indexes: set[int],
        current_permutation: list[int],
        # The condition function is configurable to make this function more versatile.
        is_valid: Callable[[Sequence[int]], bool],
    ) -> Iterable[list[int]]:
        if len(current_permutation) == len(items):
            # Since we picked a valid element below, there is no need to validate it here.
            yield current_permutation
        else:
            for i in not_used_indexes:
                # This is the in-progress validation for the partial permutation.
                # If it is found to violate the conditions, we skip the further search.
                # This greatly reduces the generation of unwanted elements.
                next_permutation = [*current_permutation, items[i]]
                if not is_valid(next_permutation):
                    continue
    
                yield from _filtered_permutations(
                    items=items,
                    not_used_indexes=not_used_indexes - {i},
                    current_permutation=next_permutation,
                    is_valid=is_valid,
                )
    
    
    def enumerate_perms_pre_filter(n: int) -> Iterable[tuple[int, ...]]:
        lst = list(range(1, n - 1))
        head = n
        tail = n - 1
    
        def _is_valid(current):
            """Wrapper function that inserts the first and last elements."""
            if len(current) == len(lst):
                return is_valid_original((head, *current, tail))
            return is_valid_original((head, *current))
    
        for p in _filtered_permutations(
            items=lst,
            not_used_indexes=set(range(len(lst))),
            current_permutation=[],
            is_valid=_is_valid,
        ):
            yield head, *p, tail
    
    
    # ---------- Optimized implementation ----------
    def _enumerate_perms_optimized(
        items: Sequence[int],
        current_permutation: list[int],
        not_used_indexes: set[int],
        current_diffs: set[int],
    ) -> Iterator[list[int]]:
        if len(current_permutation) == len(items) + 1:
            # Since we picked a valid element below, there is no need to validate it here.
            yield current_permutation
        else:
            for i in not_used_indexes:
                # By carrying the differences, we can validate it by simply checking whether the set contains the value.
                next_value = items[i]
                if len(current_permutation) > 1 and next_value not in current_diffs:
                    continue
    
                yield from _enumerate_perms_optimized(
                    items=items,
                    not_used_indexes=not_used_indexes - {i},
                    current_permutation=[*current_permutation, next_value],
                    # The differences between the elements that have already been added have already been calculated,
                    # so we only need to calculate the differences between the new element and them.
                    current_diffs={*current_diffs, *(abs(next_value - prev_value) for prev_value in current_permutation)},
                )
    
    
    def enumerate_perms_optimized(n: int) -> Iterator[tuple[int, ...]]:
        lst = list(range(1, n - 1))
        head = n
        tail = n - 1
    
        for p in _enumerate_perms_optimized(
            items=lst,
            not_used_indexes=set(range(len(lst))),
            current_permutation=[head],
            current_diffs=set(),
        ):
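            # The tail needs no check: n and 1 are both already placed, so |n - 1| = n - 1 is always among the differences.
            yield *p, tail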
    
    
    def test_implementations(candidates, n: int):
        expected = sorted(candidates[0](n))
        expected = [tuple(p) for p in expected]
        for f in candidates[1:]:
            actual = sorted(f(n))
            actual = [tuple(p) for p in actual]
            assert expected == actual, f"Results differ for n={n} with {f.__name__}"
    
    
    def measure_performance(f, n_range: range):
        for n in n_range:
            n_perms = 0
            started = time.perf_counter()
            for _ in f(n):
                n_perms += 1
            elapsed = time.perf_counter() - started
            print(f"{f.__name__}({n=}): {elapsed:.3f} sec, {n_perms=:,}")
    
    
    def main():
        candidates = [
            enumerate_perms_1,
            enumerate_perms_2,
            enumerate_perms_pre_filter,
            enumerate_perms_optimized,
        ]
        for n in range(3, 10):
            test_implementations(candidates, n)
            print(f"Tests passed for {n=}")
    
        measure_performance(enumerate_perms_1, range(10, 14))
        measure_performance(enumerate_perms_2, range(10, 14))
        measure_performance(enumerate_perms_pre_filter, range(10, 16))
        measure_performance(enumerate_perms_optimized, range(10, 21))
    
    
    main()
    

    Result:

    Tests passed for n=3
    Tests passed for n=4
    Tests passed for n=5
    Tests passed for n=6
    Tests passed for n=7
    Tests passed for n=8
    Tests passed for n=9
    enumerate_perms_1(n=10): 0.029 sec, n_perms=36
    enumerate_perms_1(n=11): 0.268 sec, n_perms=598
    enumerate_perms_1(n=12): 2.413 sec, n_perms=1,096
    enumerate_perms_1(n=13): 27.832 sec, n_perms=14,030
    enumerate_perms_2(n=10): 0.035 sec, n_perms=36
    enumerate_perms_2(n=11): 0.327 sec, n_perms=598
    enumerate_perms_2(n=12): 3.098 sec, n_perms=1,096
    enumerate_perms_2(n=13): 34.459 sec, n_perms=14,030
    enumerate_perms_pre_filter(n=10): 0.001 sec, n_perms=36
    enumerate_perms_pre_filter(n=11): 0.025 sec, n_perms=598
    enumerate_perms_pre_filter(n=12): 0.058 sec, n_perms=1,096
    enumerate_perms_pre_filter(n=13): 0.982 sec, n_perms=14,030
    enumerate_perms_pre_filter(n=14): 2.595 sec, n_perms=29,340
    enumerate_perms_pre_filter(n=15): 24.978 sec, n_perms=223,350
    enumerate_perms_optimized(n=10): 0.000 sec, n_perms=36
    enumerate_perms_optimized(n=11): 0.003 sec, n_perms=598
    enumerate_perms_optimized(n=12): 0.005 sec, n_perms=1,096
    enumerate_perms_optimized(n=13): 0.064 sec, n_perms=14,030
    enumerate_perms_optimized(n=14): 0.138 sec, n_perms=29,340
    enumerate_perms_optimized(n=15): 1.095 sec, n_perms=223,350
    enumerate_perms_optimized(n=16): 9.790 sec, n_perms=1,936,172
    enumerate_perms_optimized(n=17): 164.876 sec, n_perms=28,038,794
    enumerate_perms_optimized(n=18): 542.482 sec, n_perms=90,125,652
    

    As you can see, it is hundreds of times faster than the original implementation, but it will still take several hours if not days for n=20.

    Finally, I would also like to mention the options that were mentioned in the comments.

    All of the code above is PyPy-compatible, so if you install PyPy you can run it without any modifications, and it will probably be several times faster. However, that is about all the speedup you can expect from PyPy.

    Cython may be more effective, but you would need to learn its syntax, which may take some effort.

    Numba does not work well with this approach because it does not support recursive generator functions, so some major changes would be needed. I'm not sure we can get good performance without significantly complicating the code.
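
    As an illustration of the kind of restructuring that would be involved, here is a minimal sketch that replaces the recursive generator with an explicit loop over preallocated lists and only counts the permutations instead of yielding them. The function name count_perms_iterative and its bookkeeping arrays are hypothetical; this is plain Python and has not been verified under Numba.

```python
def count_perms_iterative(n: int) -> int:
    """Count valid permutations with an explicit backtracking loop (no recursion, no generator)."""
    if n < 3:
        return 0
    m = n - 2                   # the middle values are 1..m; head is n, tail is n - 1
    prefix = [0] * (m + 1)      # prefix[0] is the head, prefix[1..m] are the middle values
    prefix[0] = n
    used = [False] * (m + 1)    # used[v] is True while value v is in the prefix
    diff_count = [0] * (n + 1)  # diff_count[d] = number of prefix pairs with |a - b| == d
    candidate = [0] * (m + 1)   # candidate[depth] = value currently placed at that depth (0 = none)
    depth = 1
    count = 0
    while depth >= 1:
        v = candidate[depth]
        if v != 0:
            # Undo the value previously placed at this depth.
            for j in range(depth):
                diff_count[abs(v - prefix[j])] -= 1
            used[v] = False
        # Advance to the next unused value that is an existing difference.
        # The element right after the head (depth 1) is unconstrained.
        v += 1
        while v <= m and (used[v] or (depth > 1 and diff_count[v] == 0)):
            v += 1
        if v > m:
            # No candidates left at this depth: backtrack.
            candidate[depth] = 0
            depth -= 1
            continue
        # Place v and register its differences with every earlier element.
        candidate[depth] = v
        prefix[depth] = v
        used[v] = True
        for j in range(depth):
            diff_count[abs(v - prefix[j])] += 1
        if depth == m:
            # All middle values placed; the tail n - 1 is always admissible
            # because both n and 1 are in the prefix.
            count += 1
        else:
            depth += 1
    return count


print(count_perms_iterative(5))  # 2
```

    For n=5 this returns 2, matching the two permutations listed in the question. Counting pairs per difference (rather than keeping a set) is what makes the undo step possible without copying any state.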