I found an interesting problem recently, which looks like this:
There is a dull sorting algorithm that looks at the first number of the array, finds the element that is smaller than it by 1 (or the highest element when there is no smaller one), and moves that element to the front. The cost of moving the element with index x (counting from 0) to the front is equal to its index x. It repeats this step until the array is sorted. The task is to compute the total cost of sorting all n! permutations of the numbers from 1 to n. The answer can be large, so it should be given modulo m (n and m are given in the input).
Example:
Input (n,m): 3 40
Answer: 15
There are 3! = 6 permutations of the numbers from 1 to 3. The costs of sorting them are:
(1,2,3)->0
(1,3,2)->5
(2,1,3)->1
(2,3,1)->2
(3,1,2)->4
(3,2,1)->3
sum = 15
My program generates all the possible arrays and sorts them one by one. Its complexity is O(n!*n^2), which is far too slow. I am stuck with this brute-force solution and have run out of ideas.
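For reference, the brute force described above might look like the following minimal sketch. It assumes the values 1..n exactly as in the statement. The names `dull_sort_cost` and `brute_total` are hypothetical and are not taken from the original program. Replaying (1,3,2) reproduces the three moves of cost 1, 2 and 2 that make up its cost of 5 in the example.

```python
from itertools import permutations


def dull_sort_cost(arr):
    """Replay the dull sort on a permutation of 1..n; return (total cost, moves)."""
    arr = list(arr)
    n = len(arr)
    total = 0
    moves = []  # (index the element was taken from, array after the move)
    while arr != sorted(arr):
        # The element one lower than the first, or the highest (n) if the first is 1.
        target = arr[0] - 1 if arr[0] > 1 else n
        i = arr.index(target)
        total += i  # moving the element at index i to the front costs i
        arr.insert(0, arr.pop(i))
        moves.append((i, tuple(arr)))
    return total, moves


def brute_total(n, m):
    """Sum the cost over all n! permutations of 1..n, modulo m: O(n! * n^2)."""
    return sum(dull_sort_cost(p)[0] for p in permutations(range(1, n + 1))) % m


print(dull_sort_cost((1, 3, 2)))  # (5, [(1, (3, 1, 2)), (2, (2, 3, 1)), (2, (1, 2, 3))])
print(brute_total(3, 40))         # 15
```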
There are also some funny things I have discovered:
The sorting algorithm has two phases: it first sorts the permutation into some rotation of the identity and then rotates it to the identity. We account for the cost of these phases separately.
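To see the two phases concretely, here is a sketch that splits the cost of one run into its first-phase and second-phase parts. It assumes values 0..n−1 and the same move rule as `sort_cost` in the code below. `is_rotation` and `phase_costs` are hypothetical helper names. Over the six permutations of 0..2, the phase totals come out as 3 and 12, which add up to the example's 15.

```python
from itertools import permutations


def is_rotation(perm):
    """True if perm is a cyclic shift of (0, 1, ..., n-1), e.g. (1, 2, 0)."""
    n = len(perm)
    return all(perm[(i + 1) % n] == (perm[i] + 1) % n for i in range(n))


def phase_costs(perm):
    """Return (first-phase cost, second-phase cost) for a permutation of 0..n-1."""
    perm = list(perm)
    n = len(perm)
    first = second = 0
    while perm != sorted(perm):
        i = perm.index((perm[0] - 1) % n)  # same move rule as sort_cost
        if is_rotation(perm):
            second += i  # already a rotation: rotating toward the identity
        else:
            first += i   # still sorting into some rotation
        perm.insert(0, perm.pop(i))
    return first, second


totals = [0, 0]
for p in permutations(range(3)):
    first, second = phase_costs(p)
    totals[0] += first
    totals[1] += second
print(totals)  # [3, 12]
```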
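The first phase consists of at most n−2 moves. Number the values 0, …, n−1 (as in the code below). After n−1−j moves, the permutation consists of n−j consecutive values x, x+1, x+2, … mod n, where x is the current first element. These are followed by a permutation of the remaining j values. Assuming that we start from a uniformly random permutation, those j values are equally likely to be in any particular order. The next move brings x−1 mod n to the front. It is equally likely to sit at any index from n−j to n−1, so the expected distance that we have to move it is ((n−j)+(n−1))/2. But hang on, we only count the move if we’re still in the first phase. Thus we need to discount the cases where the permutation is already a rotation. Of the n! starting permutations, n!/j! are rotations at this point, because the remaining j values must appear in the one correct order, which happens with probability 1/j!. They all have x−1 at the end, so the discount for each is n−1.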
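Summed over j, the first-phase total is written out below as a formula. This is a restatement of the loop in `fast_total_sort_cost`, and C₁ is just a label introduced here. For n = 3 the two terms are 3 and 0. The j = 1 term always vanishes, since it equals n!(n−1) − n!(n−1). This is consistent with the first phase having at most n−2 moves.

```latex
C_1(n) = \sum_{j=1}^{n-1} \left[ \frac{n!\,\bigl((n-j)+(n-1)\bigr)}{2} - \frac{n!}{j!}\,(n-1) \right]

C_1(3) = \Bigl(\tfrac{6 \cdot 3}{2} - 3 \cdot 2\Bigr) + \Bigl(\tfrac{6 \cdot 4}{2} - 6 \cdot 2\Bigr) = 3 + 0 = 3
```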
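The second phase consists, on average, of (n−1)/2 moves, each bringing the last element to the front at a cost of n−1. By symmetry, the algorithm commutes with adding a constant to every value mod n. The rotation reached at the end of the first phase is therefore equally likely to be any of the n rotations, and the rotation starting with value s needs s moves. The average second-phase cost per permutation is thus (n−1)²/2, or n!(n−1)²/2 in total.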
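As a worked check of the second phase (C₂ is just a label introduced here), the total over all permutations and its value for n = 3 are below. It agrees with the example: among the 1-based permutations, (1,3,2) and (3,1,2) each pay 4 in the second phase, while (2,3,1) and (3,2,1) each pay 2. That gives 4 + 2 + 4 + 2 = 12.

```latex
C_2(n) = n! \cdot \frac{n-1}{2} \cdot (n-1) = \frac{n!\,(n-1)^2}{2},
\qquad
C_2(3) = \frac{6 \cdot 2^2}{2} = 12
```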
I’ll leave the modular arithmetic/strength reduction of the Python below as an exercise.
```python
from itertools import permutations
from math import factorial

# Returns the total cost of sorting all n-element permutations.
def fast_total_sort_cost(n):
    cost = 0
    for j in range(n - 1, 0, -1):
        # First phase: expected distance of the move, summed over all n! permutations...
        cost += factorial(n) * ((n - j) + (n - 1)) // 2
        # ...minus the moves made by permutations that are already rotations.
        cost -= (factorial(n) // factorial(j)) * (n - 1)
    # Second phase: (n - 1) / 2 moves of cost n - 1 on average.
    return cost + factorial(n) * (n - 1) ** 2 // 2

# Reference implementation and test.
def reference_total_sort_cost(n):
    return sum(sort_cost(perm) for perm in permutations(range(n)))

def sort_cost(perm):
    cost = 0
    perm = list(perm)
    while not is_sorted(perm):
        # Move (first element - 1) mod n to the front; its index is the cost.
        i = perm.index((perm[0] - 1) % len(perm))
        cost += i
        perm.insert(0, perm.pop(i))
    return cost

def is_sorted(perm):
    return all(perm[i - 1] <= perm[i] for i in range(1, len(perm)))

if __name__ == "__main__":
    # Compare the closed form with brute force for small n.
    for n in range(1, 9):
        print(fast_total_sort_cost(n), reference_total_sort_cost(n))
```
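The modular arithmetic and strength reduction left as an exercise above might be done as in the following sketch. It rests on assumptions that are not part of the answer: `total_sort_cost_mod` is a hypothetical name, m ≥ 1, and the formula is exactly the one in `fast_total_sort_cost`. Two tricks avoid division modulo m. First, n!/2 is computed as 3·4·…·n, so no inverse of 2 is needed; that inverse does not exist for an even m such as 40. Second, n!/j! is kept as a running product that is multiplied by j as j decreases, so no division by j! is needed either. The whole computation takes O(n) time.

```python
def total_sort_cost_mod(n, m):
    """Total cost of sorting all n! permutations, modulo m (same formula as
    fast_total_sort_cost, without big factorials or modular division)."""
    if n < 2:
        return 0 % m
    # n!/2 mod m as 3 * 4 * ... * n (since 1 * 2 / 2 == 1); valid for any m.
    half_fact = 1
    for k in range(3, n + 1):
        half_fact = half_fact * k % m
    # Second phase: n! * (n - 1)^2 / 2.
    cost = half_fact * (n - 1) * (n - 1) % m
    # falling holds n!/j! mod m, starting at j = n - 1 where it equals n.
    falling = n % m
    for j in range(n - 1, 0, -1):
        # First phase term for this j: n! * ((n - j) + (n - 1)) / 2 - (n!/j!) * (n - 1).
        cost = (cost + half_fact * ((n - j) + (n - 1)) - falling * (n - 1)) % m
        falling = falling * j % m  # n!/(j-1)! = (n!/j!) * j
    return cost


print(total_sort_cost_mod(3, 40))  # 15
```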