
Find the n-th element of a sequence. How can I speed up the program (time) for n > 10^6?


I think I have the right algorithm, but when the values increase to 10^6 and more, I exceed the MEMORY or TIMELIMIT allowed. At first I pushed the elements into a vector; then I changed the method to reuse variables, and more tests passed.

Formula: $A_i = (A_{i-1} + 2 A_{i-2} + 3 A_{i-3}) \bmod M$, where $M = 10^9 + 7$.
Constraints: $1 \le n \le 10^{12}$. TIMELIMIT: 1 sec, MEMORY: 256 MB
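As a worked check of the indexing, assuming the initial values used in the code below ($A_1 = A_2 = 1$, $A_3 = 2$; the statement itself does not list them), the first few terms come out as follows, which matches Test 1 (input 6, output 34):

```latex
\begin{aligned}
A_4 &= A_3 + 2A_2 + 3A_1 = 2 + 2 + 3 = 7 \\
A_5 &= A_4 + 2A_3 + 3A_2 = 7 + 4 + 3 = 14 \\
A_6 &= A_5 + 2A_4 + 3A_3 = 14 + 14 + 6 = 34
\end{aligned}
```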

Code:

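    #include <iostream>

    using namespace std;
    using ull = unsigned long long;

    const ull MOD = 1000000007;  // 10^9 + 7, computed once instead of calling pow() every iteration

    ull func(ull n) {
        ull a = 1;  // A(i-3)
        ull b = 1;  // A(i-2)
        ull c = 2;  // A(i-1)
        if (n <= 2) return 1;  // without this, n - 3 underflows for n == 2 and the loop never ends
        for (ull i = 0; i < n - 3; i++) {
            ull res = (3 * a + 2 * b + c) % MOD;
            a = b;
            b = c;
            c = res;
        }
        return c;
    }

    int main() {
        ull x;  // n can be up to 10^12, which does not fit in an int
        cin >> x;
        cout << func(x);
    }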

Now I have an algorithm that passes the 3 initial tests but fails at test 63 (where I think the values are > 10^6).

Test 1 Input: 6 Output: 34

Test 2 Input: 10 Output: 1096

Test 3 Input: 500 Output: 340736120

Do I need to change the algorithm or speed up by any methods?


Solution

  • Your current solution is O(n), which is much too slow when n can be as large as 10^12: roughly 10^12 loop iterations cannot realistically finish within the 1-second limit.

    We can find a 3×3 matrix M (unrelated to the modulus M in the question) such that multiplying by it transitions from one state to the next. M satisfies

    $[A_i, A_{i-1}, A_{i-2}]^T = M \cdot [A_{i-1}, A_{i-2}, A_{i-3}]^T$

    Clearly, the last row of M is simply [0, 1, 0], which picks out $A_{i-2}$.

    Similarly, the second row is [1, 0, 0], which picks out $A_{i-1}$.

    The first row is [1, 2, 3], which directly comes from the recurrence relation.
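    Putting the three rows together, the transition reads as follows (all arithmetic taken modulo $10^9 + 7$):

    ```latex
    \begin{bmatrix} A_i \\ A_{i-1} \\ A_{i-2} \end{bmatrix}
    =
    \begin{bmatrix} 1 & 2 & 3 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix}
    \begin{bmatrix} A_{i-1} \\ A_{i-2} \\ A_{i-3} \end{bmatrix}
    \pmod{10^9 + 7}
    ```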

    Now, for n ≥ 3, we can find the nth element of the sequence by left-multiplying the initial conditions, the column vector $[A_3, A_2, A_1]^T = [2, 1, 1]^T$, by M a total of n - 3 times, then reading the answer off the first entry. This is equivalent to multiplying by $M^{n-3}$. With binary exponentiation, matrix exponentiation takes $O(S^3 \log N)$ time, where S is the dimension of the matrix (in this case, the constant 3) and N is the exponent.
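    As a worked example, for n = 6 the exponent is n - 3 = 3, so M is applied three times to the initial vector, and the first entry reproduces the expected output 34 from Test 1:

    ```latex
    M\begin{bmatrix}2\\1\\1\end{bmatrix} = \begin{bmatrix}7\\2\\1\end{bmatrix},\quad
    M\begin{bmatrix}7\\2\\1\end{bmatrix} = \begin{bmatrix}14\\7\\2\end{bmatrix},\quad
    M\begin{bmatrix}14\\7\\2\end{bmatrix} = \begin{bmatrix}34\\14\\7\end{bmatrix}
    ```

    For n = 10^12 the exponent has 40 bits ($\log_2 10^{12} \approx 39.9$), so binary exponentiation needs about 40 squarings plus at most 40 extra multiplications of 3×3 matrices (27 scalar multiplications each), instead of roughly 10^12 loop iterations.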

    This leads to the following solution:

    #include <iostream>
    #include <vector>
    #include <span>  // requires C++20
    #include <initializer_list>
    #include <stdexcept>
    #include <cstddef>
    constexpr int MOD = 1e9 + 7;
    template<typename T>
    class Matrix {
        std::size_t rows, cols;
        std::vector<std::vector<T>> values;
    
    public:
        Matrix(std::size_t rows, std::size_t cols) : rows{rows}, cols{cols}, values(rows, std::vector<T>(cols)) {}
        Matrix(std::initializer_list<std::initializer_list<T>> initVals) : rows{initVals.size()} {
            values.reserve(rows);
            for (auto& row : initVals) {
                values.emplace_back(row);
                if ((cols = row.size()) != values[0].size()) throw std::domain_error("Not a matrix: rows have unequal size");
            }
        }
        std::span<T> operator[](std::size_t r) {
            return values[r];
        }
        std::span<const T> operator[](std::size_t r) const {
            return values[r];
        }
        static Matrix identity(std::size_t size) {
            Matrix id(size, size);
            for (std::size_t i = 0; i < size; ++i) id.values[i][i] = 1;
            return id;
        }
        Matrix operator*(const Matrix& m) const {
            if (cols != m.rows) throw std::domain_error("Matrix dimensions do not match");
            Matrix res(rows, m.cols);
            for (std::size_t r = 0; r < rows; ++r)
                for (std::size_t c = 0; c < m.cols; ++c)
                    for (std::size_t i = 0; i < cols; ++i)
                        res.values[r][c] += values[r][i] * m.values[i][c];
            return res;
        }
        Matrix operator%(T mod) const {
            auto res = *this;
            for (std::size_t r = 0; r < rows; ++r)
                for (std::size_t c = 0; c < cols; ++c)
                    res.values[r][c] %= mod;
            return res;
        }
        Matrix modPow(std::size_t exp, T mod) const {
            if (rows != cols) throw std::domain_error("Matrix is not square");
            auto res = identity(rows), sq = *this;
            for (; exp; exp >>= 1) {
                if (exp & 1) res = res * sq % mod;
                sq = sq * sq % mod;
            }
            return res;
        }
    };
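    // Transition matrix M and the initial state [A3, A2, A1]^T = [2, 1, 1]^T
    const Matrix<unsigned long long> transition{{1, 2, 3}, {1, 0, 0}, {0, 1, 0}}, 
                                     initialConditions{{2}, {1}, {1}};
    // Returns A_n mod 10^9 + 7, with A1 = A2 = 1 and A3 = 2
    unsigned long long nthValue(unsigned long long n){
        if (n < 3) return 1;
        // M^(n-3) * [A3, A2, A1]^T = [An, An-1, An-2]^T; read off the first entry
        return (transition.modPow(n - 3, MOD) * initialConditions % MOD)[0][0];
    }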
    
    int main() {
        unsigned long long n; 
        std::cin >> n;
        std::cout << nthValue(n) << '\n';
    }