Tags: python, math, combinations, dynamic-programming, permutation

Given a rod of length N, you need to cut it into R pieces such that each piece's length is positive. How many ways are there to do so?


Description:

Given two positive integers N and R, how many different ways are there to cut a rod of length N into R pieces, such that the length of each piece is a positive integer? Output this answer modulo 1,000,000,007.

Example:

With N = 7 and R = 3, there are 15 ways to cut a rod of length 7 into 3 pieces: (1,1,5), (1,5,1), (5,1,1), (1,2,4), (1,4,2), (2,1,4), (2,4,1), (4,1,2), (4,2,1), (1,3,3), (3,1,3), (3,3,1), (2,2,3), (2,3,2), (3,2,2).
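To sanity-check the example, a brute-force enumeration can list the cuttings directly. This is only a sketch for tiny N, since it tries every tuple of lengths; the helper name `compositions` is hypothetical and not part of the problem statement.

```python
from itertools import product

def compositions(n, r):
    """Yield every ordered tuple of r positive integers that sums to n (brute force)."""
    # Each piece is at most n - r + 1 long, because the other r - 1 pieces need length >= 1.
    for parts in product(range(1, n - r + 2), repeat=r):
        if sum(parts) == n:
            yield parts

ways = list(compositions(7, 3))
print(len(ways))  # 15
```

Order matters here, which is why (1,1,5), (1,5,1) and (5,1,1) are counted separately.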

Constraints:

1 <= R <= N <= 200,000

Testcases:

 N    R       Output
 7    3           15
36    6       324632
81   66    770289477
96   88    550930798

My approach:

I know that the answer is (N-1 choose R-1) mod 1000000007. I have tried several ways to calculate it, but 7 out of 10 test cases always exceed the time limit. Here is my code. Can anyone suggest another approach that gets it down to O(1) time complexity?
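The formula follows from a stars-and-bars argument: a rod of length N has N-1 interior integer positions (1, 2, ..., N-1), and choosing R-1 distinct positions as cut points fixes exactly one sequence of positive piece lengths, and vice versa. A small sketch of that bijection (the helpers `pieces_from_cuts` and `all_cuttings` are hypothetical), compared against `math.comb`, which needs Python 3.8+:

```python
from itertools import combinations
from math import comb

def pieces_from_cuts(n, cuts):
    """Turn increasing cut positions (taken from 1..n-1) into piece lengths."""
    bounds = (0, *cuts, n)
    return tuple(b - a for a, b in zip(bounds, bounds[1:]))

def all_cuttings(n, r):
    """Every way to cut a rod of length n into r positive pieces, via cut points."""
    # combinations() yields each set of r - 1 distinct interior positions once,
    # in increasing order, so each set maps to a different tuple of lengths.
    return {pieces_from_cuts(n, cuts) for cuts in combinations(range(1, n), r - 1)}

print(len(all_cuttings(7, 3)), comb(7 - 1, 3 - 1))  # 15 15
```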

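from math import factorial

def new(n, r):
    # Computes C(n-1, r-1) exactly with big integers before reducing it.
    # For n near 200,000, (n-1)! has close to a million digits, which is
    # likely what makes this too slow.
    D = factorial(n - 1) // (factorial(r - 1) * factorial(n - r))
    return D % 1000000007

if __name__ == '__main__':
    # Test cases from the table above.
    N = [7, 36, 81, 96]
    R = [3, 6, 66, 88]
    answer = [new(n, r) for n, r in zip(N, R)]
    print(answer)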
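If each run only answers a single query, a precomputed table is not strictly needed. The sketch below (the name `ways_single` is hypothetical) multiplies only min(R-1, N-R) terms, reduces every intermediate modulo 1000000007 so the numbers stay small, and then divides with one modular inverse obtained from Fermat's little theorem, which applies because 1000000007 is prime. This is O(R) per query rather than O(1), so it is a middle ground rather than a replacement for a precomputed table.

```python
MOD = 1_000_000_007

def ways_single(n, r):
    """C(n-1, r-1) mod MOD using about min(r-1, n-r) modular multiplications."""
    top, k = n - 1, r - 1
    k = min(k, top - k)  # C(a, b) == C(a, a - b); iterate over the smaller side
    num = den = 1
    for i in range(k):
        num = num * (top - i) % MOD  # falling product top * (top-1) * ...
        den = den * (i + 1) % MOD    # k! mod MOD
    # MOD is prime and k < MOD, so den is nonzero mod MOD and its inverse
    # is den^(MOD-2) by Fermat's little theorem.
    return num * pow(den, MOD - 2, MOD) % MOD

print([ways_single(n, r) for n, r in [(7, 3), (36, 6)]])  # [15, 324632]
```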

Solution

  • I think there are two big optimizations the problem is looking for you to exploit. The first is to cache intermediate values of factorial() to save computation across large batches of test cases (large T). The second is to reduce your value mod 1000000007 incrementally, so your numbers stay small and each multiplication stays constant-time. Division is then done with a modular inverse, pow(x, -1, m), which requires Python 3.8+. I've updated the example below to precompute a factorial table using a custom function and itertools.accumulate, instead of merely caching the calls in a recursive implementation (this also avoids the recursion-depth issues you were seeing).

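    from itertools import accumulate
    
    MOD_BASE = 1000000007
    N_BOUND = 200000
    
    def modmul(m):
        # Returns a two-argument function that multiplies x and y modulo m.
        def mul(x, y):
            return x * y % m
        return mul
    
    # FACTORIALS[i] == i! mod MOD_BASE for 0 <= i <= N_BOUND (running product).
    FACTORIALS = [1] + list(accumulate(range(1, N_BOUND+1), modmul(MOD_BASE)))
    
    def nck(n, k, m):
        # FACTORIALS is reduced mod MOD_BASE, so this is only valid for m == MOD_BASE.
        numerator = FACTORIALS[n]
        denominator = FACTORIALS[k] * FACTORIALS[n-k]
        # pow(x, -1, m) is the modular inverse (Python 3.8+). It exists here
        # because m is prime and does not divide the denominator.
        return numerator * pow(denominator, -1, m) % m
    
    def solve(n, k):
        # Cutting a rod of length n into k positive pieces: C(n-1, k-1).
        return nck(n-1, k-1, MOD_BASE)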
    

    Running this against the example:

    >>> pairs = [(36, 6), (81, 66), (96, 88)]
    >>> print([solve(n, k) for n, k in pairs])
    [324632, 770289477, 550930798]
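The version above still calls pow(denominator, -1, m) once per query, which costs O(log m) multiplications. If strictly constant work per query is the goal, one common variant (a sketch, not taken from the answer) also precomputes inverse factorials. It reuses the answer's MOD_BASE and N_BOUND values; FACT, INV_FACT and solve_table are hypothetical names. Only one Fermat inverse is computed, for N_BOUND!, and the rest follow from 1/(i-1)! = i * (1/i!). Using pow(x, MOD_BASE - 2, MOD_BASE) also works on Python versions before 3.8.

```python
MOD_BASE = 1000000007
N_BOUND = 200000

# FACT[i] == i! mod MOD_BASE
FACT = [1] * (N_BOUND + 1)
for i in range(1, N_BOUND + 1):
    FACT[i] = FACT[i - 1] * i % MOD_BASE

# INV_FACT[i] == (i!)^-1 mod MOD_BASE: one Fermat inverse for the largest
# entry, then walk down using (i-1)!^-1 == i!^-1 * i.
INV_FACT = [1] * (N_BOUND + 1)
INV_FACT[N_BOUND] = pow(FACT[N_BOUND], MOD_BASE - 2, MOD_BASE)
for i in range(N_BOUND, 0, -1):
    INV_FACT[i - 1] = INV_FACT[i] * i % MOD_BASE

def solve_table(n, k):
    """Ways to cut a rod of length n into k positive pieces: C(n-1, k-1) mod MOD_BASE."""
    a, b = n - 1, k - 1
    return FACT[a] * INV_FACT[b] % MOD_BASE * INV_FACT[a - b] % MOD_BASE

print([solve_table(n, k) for n, k in [(7, 3), (36, 6)]])  # [15, 324632]
```

The two tables cost O(N_BOUND) time and memory once; after that, each query is three lookups and two modular multiplications.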