
Determining whether the coefficient of the x^m term in (x^2 + x + 1)^n is even or odd


For given integers n and m, how can I determine whether the coefficient of the x^m term in (x^2 + x + 1)^n is even or odd?

For example, if n = 3 and m = 4, then (x^2 + x + 1)^3 = x^6 + 3x^5 + 6x^4 + 7x^3 + 6x^2 + 3x + 1, so the coefficient of the x^4 term is 6 (even).
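The example can be reproduced by brute force. This is a minimal quadratic-time sketch (the function names are illustrative) that expands the product directly; it is far too slow for n near 10^12, but it is handy as a reference when testing a faster method.

```python
def trinomial_coefficients(n):
    """Coefficients of (1 + x + x^2)^n; index k holds the coefficient of x^k."""
    coeffs = [1]
    for _ in range(n):
        # Multiply the current polynomial by (1 + x + x^2).
        nxt = [0] * (len(coeffs) + 2)
        for k, c in enumerate(coeffs):
            nxt[k] += c
            nxt[k + 1] += c
            nxt[k + 2] += c
        coeffs = nxt
    return coeffs


def coefficient_parity_naive(n, m):
    """1 if the coefficient of x^m is odd, 0 if even. O(n^2): small n only."""
    coeffs = trinomial_coefficients(n)
    return coeffs[m] % 2 if 0 <= m < len(coeffs) else 0


print(trinomial_coefficients(3))       # [1, 3, 6, 7, 6, 3, 1]
print(coefficient_parity_naive(3, 4))  # 0, i.e. even
```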

n and m can be as large as 10^12, and I want the answer within a few seconds, so a linear-time computation is not feasible.

Is there an efficient algorithm?


Solution

  • Yes: the parity can be computed in time linear in the number of bits of the input, i.e. at most about 40 loop iterations for n, m ≤ 10^12.

    The coefficients in question are the trinomial coefficients T(n, m), i.e. the coefficient of x^m in (1 + x + x^2)^n. For binomial coefficients, we would use Lucas's theorem; let's work out the trinomial analog for p = 2.

    Working mod 2 and following the proof of Nathan Fine: squaring a polynomial mod 2 doubles every exponent (the cross terms 2ab vanish), so by induction on i,

    (1 + x + x^2)^{2^i} = 1 + x^{2^i} + x^{2^{i+1}}   (mod 2)
    
    (1 + x + x^2)^n
        = prod_i ((1 + x + x^2)^{2^i n_i})
            where n = sum_i n_i 2^i and n_i in {0, 1} for all i
            (i.e., the n_i are the binary digits of n)
        = prod_i (1 + x^{2^i} + x^{2^{i+1}})^{n_i}
        = prod_i sum_{m_i = 0}^{2 n_i} x^{2^i m_i}
        = sum_{(m_i)} prod_i x^{2^i m_i}
            taken over sequences (m_i) where 0 ≤ m_i ≤ 2 n_i.
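Both mod-2 identities can be checked mechanically by storing GF(2)[x] polynomials as Python integers (bit k is the coefficient of x^k), so that addition is XOR and multiplication is carry-less. A small sketch; `clmul` and `gf2_pow` are illustrative helpers, not library functions.

```python
def clmul(p, q):
    """Multiply two GF(2)[x] polynomials stored as int bitmasks (bit k = x^k)."""
    result = 0
    while q:
        if q & 1:
            result ^= p  # adding polynomials mod 2 is XOR
        p <<= 1
        q >>= 1
    return result


def gf2_pow(p, e):
    """p**e in GF(2)[x] by square-and-multiply."""
    result = 1
    while e:
        if e & 1:
            result = clmul(result, p)
        p = clmul(p, p)
        e >>= 1
    return result


TRI = 0b111  # 1 + x + x^2

# (1 + x + x^2)^(2^i) == 1 + x^(2^i) + x^(2^(i+1))  (mod 2)
for i in range(6):
    assert gf2_pow(TRI, 2 ** i) == 1 | (1 << 2 ** i) | (1 << 2 ** (i + 1))

# Multiplying those factors over the set bits of n gives (1 + x + x^2)^n mod 2.
n = 0b10110
product = 1
for i in range(n.bit_length()):
    if (n >> i) & 1:
        product = clmul(product, 1 | (1 << 2 ** i) | (1 << 2 ** (i + 1)))
assert product == gf2_pow(TRI, n)
```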
    

    In the binomial case, the next step is to observe that, for the coefficient of x^m, there's at most one choice of (m_i) whose x^{2^i m_i} factors have the right product, i.e., the binary representation of m.
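In the binomial case the only candidate is the binary representation of m itself, and it is admissible only if m_i ≤ n_i for every i, so C(n, m) is odd exactly when every set bit of m is also set in n. A short sketch of that bit test, checked against `math.comb` (Python 3.8+); `binomial_mod_two` is an illustrative name.

```python
from math import comb


def binomial_mod_two(n, m):
    """C(n, m) mod 2 via Lucas's theorem for p = 2.

    C(n, m) is odd exactly when no bit of m lies outside the bits of n.
    """
    return 1 if m & ~n == 0 else 0


# Compare with exact binomial coefficients for small n.
for n in range(64):
    for m in range(n + 1):
        assert binomial_mod_two(n, m) == comb(n, m) % 2
```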

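    In the trinomial case, the sequences (m_i) are binary pseudo-representations of m, i.e. m = sum_i m_i 2^i with pseudo-bits that can be zero, one, or two, and there can be more than one of them. Each pseudo-representation with m_i = 0 for every i such that n_i = 0 contributes 1 to the sum, so the coefficient of x^m is odd if and only if the number of such pseudo-representations is odd.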
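To make the pseudo-representation argument concrete, this brute-force sketch enumerates every admissible digit sequence and compares the parity of the count with the parity of the true coefficient. Only the parities agree, not the values: for n = 3, m = 4 there are 2 pseudo-representations, (m_0, m_1) = (0, 2) and (2, 1), while the coefficient is 6. The helper names are illustrative.

```python
from itertools import product


def count_pseudo_representations(n, m):
    """Number of digit sequences (m_i) with m_i in {0, 1, 2}, m_i = 0 wherever
    bit i of n is 0, and sum(m_i * 2**i) == m. Brute force: small n only."""
    bits = max(n.bit_length(), 1)
    # Digits 0..2 are allowed where n has a 1 bit; only 0 where it has a 0 bit.
    choices = [range(3) if (n >> i) & 1 else range(1) for i in range(bits)]
    return sum(
        1
        for digits in product(*choices)
        if sum(d << i for i, d in enumerate(digits)) == m
    )


def trinomial_coefficient(n, m):
    """Exact coefficient of x^m in (1 + x + x^2)^n by repeated multiplication."""
    coeffs = [1]
    for _ in range(n):
        coeffs = [sum(coeffs[k - j] for j in range(3) if 0 <= k - j < len(coeffs))
                  for k in range(len(coeffs) + 2)]
    return coeffs[m] if 0 <= m < len(coeffs) else 0


print(count_pseudo_representations(3, 4), trinomial_coefficient(3, 4))  # 2 6

# The counts differ from the coefficients, but their parities agree.
for n in range(20):
    for m in range(2 * n + 2):
        assert count_pseudo_representations(n, m) % 2 == trinomial_coefficient(n, m) % 2
```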

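    We can write an automaton that scans n and m bit by bit, least significant bit first, and emits the pseudo-bit m_i at each step. State a means there is no carry into the current position and state b means there is a carry of 1; a is both initial and accepting. A transition labeled (n_i:m_i:nm') consumes the current bits of n and m, and nm' is the rest of the input.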

    a (0:0:nm') -> a nm'    [emit 0]
    a (1:0:nm') -> a nm'    [emit 0]
                -> b nm'    [emit 2]
    a (1:1:nm') -> a nm'    [emit 1]
    
    b (0:1:nm') -> a nm'    [emit 0]
    b (1:0:nm') -> b nm'    [emit 1]
    b (1:1:nm') -> a nm'    [emit 0]
                -> b nm'    [emit 2]
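The transition table can be transcribed into a dictionary and run as a nondeterministic automaton that keeps every path. This sketch (names illustrative, exponential in the worst case, so for small inputs only) prints the pseudo-bits emitted along each accepting path for n = 3, m = 4, together with the value they represent.

```python
# (state, bit of n, bit of m) -> list of (next state, emitted pseudo-bit).
# State "a" = no carry into the current position, "b" = carry of 1.
TRANSITIONS = {
    ("a", 0, 0): [("a", 0)],
    ("a", 1, 0): [("a", 0), ("b", 2)],
    ("a", 1, 1): [("a", 1)],
    ("b", 0, 1): [("a", 0)],
    ("b", 1, 0): [("b", 1)],
    ("b", 1, 1): [("a", 0), ("b", 2)],
}


def accepting_paths(n, m):
    """Emitted digit sequences (lowest bit first) of all paths that consume
    every bit of n and m and end in the accepting state "a"."""
    paths = [("a", [])]
    while n or m:
        n1, n = n & 1, n >> 1
        m1, m = m & 1, m >> 1
        paths = [
            (nxt, digits + [d])
            for state, digits in paths
            for nxt, d in TRANSITIONS.get((state, n1, m1), [])
        ]
    return [digits for state, digits in paths if state == "a"]


for digits in accepting_paths(3, 4):
    # Each emitted sequence is a pseudo-representation of m = 4.
    print(digits, sum(d << i for i, d in enumerate(digits)))
# [0, 2, 0] 4
# [2, 1, 0] 4
```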
    

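    We can use dynamic programming to count the accepting paths mod 2. In code form:

    def trinomial_mod_two(n, m):
        """Return the coefficient of x^m in (1 + x + x^2)^n mod 2 (1 = odd, 0 = even)."""
        # a, b: number of paths (mod 2) now in state a (no carry) and b (carry 1)
        a, b = 1, 0
        while m:
            # Consume the lowest remaining bit of n and of m.
            n1, n = n & 1, n >> 1
            m1, m = m & 1, m >> 1
            if n1:
                if m1:
                    a, b = a ^ b, b
                else:
                    a, b = a, a ^ b
            elif m1:
                a, b = b, 0
            else:
                a, b = a, 0
        # Once m is exhausted, paths in state b can never return to a.
        return a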
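One way to gain confidence in `trinomial_mod_two` is to compare it against a direct mod-2 expansion for all small n and every m up to 2n + 2. The sketch repeats the function so the block runs on its own; `coefficients_mod_two` is an illustrative helper.

```python
def trinomial_mod_two(n, m):
    """Parity of the coefficient of x^m in (1 + x + x^2)^n (as above)."""
    a, b = 1, 0
    while m:
        n1, n = n & 1, n >> 1
        m1, m = m & 1, m >> 1
        if n1:
            if m1:
                a, b = a ^ b, b
            else:
                a, b = a, a ^ b
        elif m1:
            a, b = b, 0
        else:
            a, b = a, 0
    return a


def coefficients_mod_two(n):
    """Coefficients of (1 + x + x^2)^n reduced mod 2, by direct expansion."""
    coeffs = [1]
    for _ in range(n):
        padded = [0, 0] + coeffs + [0, 0]
        # new[k] = old[k] + old[k-1] + old[k-2], computed mod 2 with XOR
        coeffs = [padded[k + 2] ^ padded[k + 1] ^ padded[k]
                  for k in range(len(coeffs) + 2)]
    return coeffs


for n in range(100):
    expected = coefficients_mod_two(n)
    for m in range(2 * n + 3):
        want = expected[m] if m < len(expected) else 0
        assert trinomial_mod_two(n, m) == want, (n, m)

# Large inputs are cheap: the loop runs m.bit_length() times (40 for 10**12).
result = trinomial_mod_two(10**12, 10**12)
```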
    

    Branchless version for giggles:

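    def trinomial_mod_two_branchless(n, m):
        a, b = 1, 0
        while m:
            n1, n = n & 1, n >> 1
            m1, m = m & 1, m >> 1
            # All values are 0 or 1; ~m1 is -1 (all bits set) when m1 == 0 and
            # -2 (bit 0 clear) when m1 == 1, so on bit 0 it acts as "not m1".
            a, b = ((n1 | ~m1) & a) ^ (m1 & b), ((n1 & ~m1) & a) ^ (n1 & b)
        return a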
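Because the branchless update relies on how Python's `~` behaves on 0 and 1, an exhaustive check of a single DP step over all 16 combinations of (n1, m1, a, b) is a cheap way to confirm that it matches the branching version. The step-function names are illustrative.

```python
from itertools import product


def step_branching(n1, m1, a, b):
    """One DP step of trinomial_mod_two, written with branches."""
    if n1:
        return (a ^ b, b) if m1 else (a, a ^ b)
    return (b, 0) if m1 else (a, 0)


def step_branchless(n1, m1, a, b):
    """The same step as in trinomial_mod_two_branchless (all inputs 0 or 1)."""
    return ((n1 | ~m1) & a) ^ (m1 & b), ((n1 & ~m1) & a) ^ (n1 & b)


# Exhaustively compare the two step functions on every bit combination.
for n1, m1, a, b in product((0, 1), repeat=4):
    assert step_branching(n1, m1, a, b) == step_branchless(n1, m1, a, b)
print("branchless step matches on all 16 inputs")
```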