Tags: algorithm, math, pascal

Finding the sum of a geometric sequence modulo 10^9+7


The problem is: output (A^1 + A^2 + A^3 + ... + A^K) modulo 1,000,000,007, where A and K are integers with 1 ≤ A, K ≤ 10^9.

I am trying to write a program to solve this. I used the formula for the sum of a geometric sequence and then applied the modulo to the result. Since the result must be an integer as well, I assumed that finding a modular inverse is not required.

Below is my current code, written in Pascal:

Var
a,k,i:longint;
power,sum: int64;
Begin
    Readln(a,k);
    power := 1;
    For i := 1 to k do
        power := ((power mod 1000000007) * a) mod 1000000007;
    sum := a * (power-1) div (a-1);
    Writeln(sum mod 1000000007);
End.

This task comes from my school, which does not give its test data to students, so I do not know why or where my program goes wrong. I only know that my program outputs the wrong answer on their test data.
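A small Java check makes one failure mode concrete. It is a sketch, and the class and method names are hypothetical. It mirrors the Pascal program's reduce-then-divide formula and compares it with a brute-force sum for A = 3, K = 20. Once A^K has been reduced modulo 1,000,000,007, A * (A^K - 1) is no longer guaranteed to be divisible by A - 1, so the integer division produces a different residue. Separately, the Pascal loop runs K times (up to 10^9 iterations), and A = 1 makes a-1 zero. Either of these may also cause failures on the school's tests.

```java
// Hypothetical demo class: compares the question's "reduce, then divide" formula
// with a brute-force sum, both taken modulo 1_000_000_007.
class ReduceThenDivideDemo {
    static final long MOD = 1_000_000_007L;

    // Mirrors the Pascal program: power = A^K mod MOD, then A*(power-1) div (A-1).
    // Integer division is only valid on the exact value, not on a reduced one.
    // (Throws ArithmeticException when a == 1, like the Pascal division by zero.)
    static long naive(long a, long k) {
        long power = 1;
        for (long i = 0; i < k; i++) {
            power = (power * a) % MOD;
        }
        return (a * (power - 1) / (a - 1)) % MOD;
    }

    // Reference answer: add A^1..A^K term by term, reducing after every step (O(K)).
    static long bruteForce(long a, long k) {
        long sum = 0;
        long term = 1;
        for (long i = 1; i <= k; i++) {
            term = (term * a) % MOD;
            sum = (sum + term) % MOD;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(naive(3, 20));      // 730176568
        System.out.println(bruteForce(3, 20)); // 230176565
    }
}
```

The two results differ (730176568 versus 230176565), even though the exact sum 3 + 3^2 + ... + 3^20 = 5230176600 is an integer.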


Solution

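  • The problem asks for A + A^2 + ... + A^k, which is (1 + A + A^2 + ... + A^k) - 1, so it is enough to compute the sum that starts at 1. If you want to do this without calculating a modular inverse, you can calculate that sum recursively using: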

    1 + A + A^2 + A^3 + ... + A^k

    = 1 + (A + A^2)(1 + A^2 + (A^2)^2 + ... + (A^2)^{k/2-1})

    That's for even k. For odd k:

    1 + A + A^2 + A^3 + ... + A^k

    = (1 + A)(1 + A^2 + (A^2)^2 + ... + (A^2)^{(k-1)/2})

    Since k is roughly halved in each recursive call, the resulting algorithm runs in O(log k) time. In Java:

    // Returns A + A^2 + ... + A^k, reduced modulo mod.
    static int modSumAtoAk(int A, int k, int mod)
    {
        return (modSum1ToAk(A, k, mod) + mod-1) % mod;
    }
    
    // Returns 1 + A + A^2 + ... + A^k, reduced modulo mod.
    static int modSum1ToAk(int A, int k, int mod)
    {
        long sum;
        if (k < 5) {
            //k is small -- just iterate
            sum = 0;
            long x = 1;
            for (int i=0; i<=k; ++i) {
                sum = (sum+x) % mod;
                x = (x*A) % mod;
            }
            return (int)sum;
        }
        //k is big
        int A2 = (int)( ((long)A)*A % mod );
        if ((k%2)==0) {
            // k even
            sum = modSum1ToAk(A2, (k/2)-1, mod);
            sum = (sum + sum*A) % mod;
            sum = ((sum * A) + 1) % mod;
        } else {
            // k odd
            sum = modSum1ToAk(A2, (k-1)/2, mod);
            sum = (sum + sum*A) % mod;
        }
        return (int)sum;
    }
    

    Note that I’ve been very careful to make sure that each product is done in 64 bits, and to reduce by the modulus after each one.

    With a little math, the recursion above can be converted into an iterative version that needs only constant extra space (no recursion stack):

    static int modSumAtoAk(int A, int k, int mod)
    {
        // First, we calculate the sum 1 + A + ... + A^k;
        // we'll refer to that as SUM1 in the comments below.
    
        long fac=1;
        long add=0;
    
        //INVARIANT: SUM1 = add + fac*(sum 1...A^k)
        //this will remain true as we change k
    
        while (k > 0) {
            //above INVARIANT is true here, too
    
            long newmul, newadd;
            if ((k%2)==0) {
                //k is even.  sum 1...A^k = 1+A*(sum 1...A^(k-1))
                newmul = A;
                newadd = 1;
                k-=1;
            } else {
                //k is odd.  sum 1...A^k = (1+A)*(sum 1...(A^2)^((k-1)/2))
                newmul = A+1L;
                newadd = 0;
                A = (int)(((long)A) * A % mod);
                k = (k-1)/2;
            }
            //SUM1 = add + fac * (newadd + newmul*(sum 1...A^k)), with the updated A and k
            //     = add+fac*newadd + fac*newmul*(sum 1...A^k)
    
            add = (add+fac*newadd) % mod;
            fac = (fac*newmul) % mod;
    
            //INVARIANT is restored
        }
    
        // k == 0, so (sum 1...A^k) == 1 and SUM1 = add + fac
        long sum1 = fac + add;
        // subtract the leading 1 to get A + A^2 + ... + A^k
        return (int)((sum1 + mod -1) % mod);
    }
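As a quick sanity check of the two identities used by the recursive version, here they are worked out for A = 2 with an even k (4) and an odd k (5). At this size, plain integer arithmetic suffices and no modulus is needed:

```latex
\begin{aligned}
k = 4 \text{ (even)}: &\quad 1 + 2 + 4 + 8 + 16 = 1 + (2 + 4)(1 + 4) = 1 + 6 \cdot 5 = 31 \\
k = 5 \text{ (odd)}:  &\quad 1 + 2 + 4 + 8 + 16 + 32 = (1 + 2)(1 + 4 + 16) = 3 \cdot 21 = 63
\end{aligned}
```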
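To see the loop invariant of the iterative version in action, here is a hand trace of modSumAtoAk(2, 5, mod) with SUM1 = 1 + 2 + 4 + 8 + 16 + 32 = 63. The modulus never comes into play for numbers this small. Each line shows the state after one pass through the loop, and confirms that add + fac * (1 + A + ... + A^k) still equals 63:

```latex
\begin{aligned}
\text{start}:         &\quad A = 2,\ k = 5,\ \mathit{fac} = 1,\ \mathit{add} = 0:  && 0 + 1 \cdot (1 + 2 + 4 + 8 + 16 + 32) = 63 \\
k = 5 \text{ (odd)}:  &\quad A = 4,\ k = 2,\ \mathit{fac} = 3,\ \mathit{add} = 0:  && 0 + 3 \cdot (1 + 4 + 16) = 63 \\
k = 2 \text{ (even)}: &\quad A = 4,\ k = 1,\ \mathit{fac} = 12,\ \mathit{add} = 3: && 3 + 12 \cdot (1 + 4) = 63 \\
k = 1 \text{ (odd)}:  &\quad A = 16,\ k = 0,\ \mathit{fac} = 60,\ \mathit{add} = 3: && 3 + 60 \cdot 1 = 63
\end{aligned}
```

At k = 0, the function computes sum1 = fac + add = 63 and returns 62 = 2 + 4 + 8 + 16 + 32.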
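For comparison, the answer's opening line implies that the formula route from the question needs a modular inverse. Below is a minimal sketch of that route. It is a different technique from the answer's, and it assumes the modulus is prime, which 1,000,000,007 is. The sketch replaces the division by A - 1 with multiplication by (A - 1)^(p-2) mod p (Fermat's little theorem), and computes A^K by binary exponentiation instead of a K-step loop. The class and method names are hypothetical. The case A ≡ 1 (mod p) needs its own branch, because A - 1 then has no inverse:

```java
// Hypothetical class: geometric sum via the closed-form formula and a modular inverse.
// Assumes MOD is prime (true for 1_000_000_007), so Fermat's little theorem applies.
class GeometricSumFermat {
    static final long MOD = 1_000_000_007L;

    // Binary exponentiation: base^exp mod MOD using O(log exp) multiplications.
    static long modPow(long base, long exp) {
        long result = 1;
        base = ((base % MOD) + MOD) % MOD; // normalize into [0, MOD)
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }

    // Returns (A^1 + A^2 + ... + A^K) mod MOD for A >= 1, K >= 1.
    static long sumAtoAk(long a, long k) {
        a %= MOD;
        if (a == 1) {
            // A - 1 is 0 mod MOD and has no inverse; every term is 1, so the sum is K.
            return k % MOD;
        }
        // A * (A^K - 1) / (A - 1): the division becomes a multiplication by
        // (A - 1)^(MOD - 2), the inverse of A - 1 by Fermat's little theorem.
        long numerator = a * ((modPow(a, k) - 1 + MOD) % MOD) % MOD;
        long inverse = modPow(a - 1, MOD - 2);
        return numerator * inverse % MOD;
    }

    public static void main(String[] args) {
        System.out.println(sumAtoAk(2, 5));             // 62
        System.out.println(sumAtoAk(3, 20));            // 230176565
        System.out.println(sumAtoAk(1, 1_000_000_000)); // 1000000000
    }
}
```

Both helpers use O(log K) multiplications, so K = 10^9 is not a problem. The answer's recursive and iterative versions avoid the inverse entirely and do not need the modulus to be prime.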