algorithm, combinatorics, number-theory, polynomials

Sum of products of all subsets


I want to calculate the sum of the products of every non-empty subset of a given $N$-element set. For example, given the set {1, 2, 3}, the answer is 1 + 2 + 3 + 1 * 2 + 1 * 3 + 2 * 3 + 1 * 2 * 3 = 23. I would also like to give the answer modulo $M$.

I know that I could calculate $(x - a_1)(x - a_2) \cdots (x - a_N) - 1$, but that would involve FFT, which could introduce rounding errors. The main problem with that idea is that it takes $O(N \log^2 N)$ time, and working modulo $M$ is problematic. Is there a faster way to solve this problem? This is not homework: my teacher gave me this task as practice for a programming contest, and I am really stuck on it.

Constraints:

$a_i \le 10^9$

$N \le 10^6$

$M \le 10^9$


Solution

  • The sum in question is

    [(1+a_1)*(1+a_2)*(1+a_3)*...*(1+a_N) - 1] (mod M)
    

    This is

    [(1+a_1)%M * (1+a_2)%M * ... * (1+a_N)%M - 1] % M
    

    I would be surprised if you could do much better.
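A short justification of the identity, written out as a sketch: expanding the product means picking either 1 or a_i from each factor, so every term of the expansion is the product of one subset S of the elements. The empty subset contributes the term 1, which is why 1 is subtracted at the end.

```latex
\prod_{i=1}^{N} (1 + a_i)
  = \sum_{S \subseteq \{1,\dots,N\}} \prod_{i \in S} a_i
  = 1 + \sum_{\emptyset \neq S \subseteq \{1,\dots,N\}} \prod_{i \in S} a_i
% Example for {1, 2, 3}: (1+1)(1+2)(1+3) - 1 = 24 - 1 = 23
```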

    Here is a Python implementation:

    def sumProducts(nums, M):
        p = 1
        for num in nums:
            p = p*((1+num)%M)%M
            if p == 0:
                return M-1
        return (p-1)%M
    

    The optimizations over the naïve formula above are to reduce the product modulo M after each new factor and to short-circuit as soon as the running product becomes zero. That happens once the prime factors of M (counted according to multiplicity) all appear among the factors (1 + a_i) processed so far. In that case the answer is (0 - 1) mod M = M - 1.
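The short-circuit condition can be illustrated with a small sketch. The helper `steps_until_zero` below is hypothetical (it is not part of the answer's code); it only reports how many factors are consumed before the running product becomes 0 mod M.

```python
def steps_until_zero(nums, M):
    """Return how many factors are consumed before the running product
    becomes 0 mod M, or None if it never does (hypothetical helper)."""
    p = 1
    for count, num in enumerate(nums, start=1):
        p = p * ((1 + num) % M) % M
        if p == 0:
            return count
    return None


# M = 12 = 2 * 2 * 3. The factors 1 + a_i are 2, 6, 5, 9.
# 2 * 6 = 12 already contains the primes 2, 2 and 3, so the product
# is 0 mod 12 after two factors.
print(steps_until_zero([1, 5, 4, 8], 12))  # 2

# M = 7 is prime and none of 2, 6, 5, 9 is divisible by 7,
# so the product never becomes 0.
print(steps_until_zero([1, 5, 4, 8], 7))   # None
```

When M is prime, the product can only reach zero if some single 1 + a_i is a multiple of M, so the short-circuit rarely fires for large prime moduli.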

    A simple test:

    >>> sumProducts([1,2,3],5)
    3
    

    which is easily verified by hand.
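For small inputs the check can also be done mechanically. The following is a minimal sketch, assuming Python 3.8+ for `math.prod`; the names `brute_force_sum` and `closed_form` are illustrative and not from the answer. It enumerates every non-empty subset directly and compares the result with the product formula.

```python
from itertools import combinations
from math import prod


def brute_force_sum(nums, M):
    """Sum the products of all non-empty subsets directly (exponential time)."""
    total = 0
    for k in range(1, len(nums) + 1):
        for subset in combinations(nums, k):
            total += prod(subset)
    return total % M


def closed_form(nums, M):
    """Same quantity via the product identity, reduced mod M at each step."""
    p = 1
    for num in nums:
        p = p * ((1 + num) % M) % M
    return (p - 1) % M


print(brute_force_sum([1, 2, 3], 5))  # 23 % 5 = 3
print(closed_form([1, 2, 3], 5))      # 3
```

The brute-force version visits 2^N - 1 subsets, so it is only useful as a reference for very small N.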

    A stress-test:

    >>> from random import randint
    >>> nums = [randint(1,1000000) for i in range(100000)]
    

    nums is a list of 100,000 random numbers in the range 1 to 1,000,000.

    Of course,

    >>> sumProducts(nums,2**32)
    4294967295
    

    since nums almost certainly contains at least 32 odd numbers (hence at least 32 numbers a_i for which 1+a_i is even), so the product is divisible by 2**32.

    On the other hand, 1000003 is a prime number greater than 1000000, so no factor 1+a_i (at most 1000001) is divisible by it and the computation doesn't short-circuit:

    >>> sumProducts(nums,1000003)
    719694
    

    The computation takes less than a second.
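Because the stress-test input is random, its output cannot be reproduced exactly. One way to gain confidence in the step-by-step reduction is to compare it with exact big-integer arithmetic on a seeded input. This is a sketch under assumptions: the seed, the list size of 2000, and the name `sum_products_mod` are arbitrary choices made here for illustration.

```python
import random
from math import prod


def sum_products_mod(nums, M):
    """Modular product with short-circuit, mirroring the answer's approach."""
    p = 1
    for num in nums:
        p = p * ((1 + num) % M) % M
        if p == 0:
            return M - 1
    return (p - 1) % M


random.seed(0)  # arbitrary seed, only to make the run repeatable
nums = [random.randint(1, 10**6) for _ in range(2000)]
M = 1000003

# Exact reference: multiply everything as one big integer, reduce once at the end.
exact = (prod(1 + a for a in nums) - 1) % M

assert sum_products_mod(nums, M) == exact
```

The exact product grows to tens of thousands of bits even for this small list, which is why the reduction at every step matters for N up to 10^6.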