c++ algorithm performance cryptography primes

Fast Algorithm for Modular Multiplication


I am implementing a large prime number generator, and generating a 2048-bit prime takes about 40 s on average. Profiling the call stack shows that about 99% of the time is spent in modular multiplication, and the running time changes considerably depending on which algorithm I use for it. I'm using boost::multiprecision::cpp_int or a type uint2048_t defined analogously to boost::multiprecision::uint1024_t. These are the two algorithms I tried; the first is much faster than the second, and I don't know why. The first, which is the one I use and which works only with boost::multiprecision integers, is a trivial multiply-then-reduce, and most of its time is spent in the modulo operation.

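#include <boost/multiprecision/cpp_int.hpp>

// Version 1: multiply at full precision, then reduce once.
// Works only with boost::multiprecision integer types.
template <typename T>
T mulMod(const T& x, const T& y, const T& p) {
    using boost::multiprecision::cpp_int;
    // Widen to arbitrary precision so the product cannot overflow T.
    cpp_int rv = (cpp_int{x} * cpp_int{y}) % cpp_int{p};
    return T{rv};
}

// Version 2: binary add-and-shift (an alternative to version 1, same name).
template <typename T>
T mulMod(T a, T b, const T& p) {
    T rv = 0;
    a %= p;
    while (b > 0) {
        if (b & 1)                // low bit of b set: add the current multiple of a
            rv = (rv + a) % p;
        a = (a << 1) % p;         // double a for the next bit of b
        b >>= 1;
    }
    return rv;
}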

Is there a faster algorithm, ideally with a C++ implementation, for modular multiplication?


Solution

  • You started out saying you want to generate prime numbers, but you did not explain how modular multiplication fits into your prime search.

    Knuth's The Art of Computer Programming, Volume 2 (Seminumerical Algorithms) has lots of material on bignum arithmetic and finding large prime numbers.

    A comment mentions Montgomery modular arithmetic. Here is a link: https://en.wikipedia.org/wiki/Montgomery_modular_multiplication

    OpenSSL has the BN (bignum) package which includes Montgomery multiply and large prime number generation.

    The GNU Multiple Precision Arithmetic Library (GMP) has similar routines.

    Your second mulMod() routine can be optimized. Inside the loop, rv and a are both already less than p, so each value it reduces mod p (rv + a, or a << 1) is less than 2*p. The mod can therefore be replaced by a single conditional subtraction: if (arg >= p) arg -= p.
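
The conditional-subtraction idea can be sketched roughly as follows. This is a minimal illustration, not a tuned implementation: the name mulModAddShift is hypothetical, and the sketch assumes p > 0 and that 2*p fits in T, so that neither rv + a nor the doubling of a can overflow. It is written against plain unsigned integer operations, so it should also accept a boost::multiprecision type, though that is untested here.

```cpp
#include <cstdint>

// Hypothetical name: add-and-shift modular multiplication in which every
// "% p" inside the loop is replaced by a compare-and-subtract.
// Assumption: p > 0 and 2 * p fits in T, so the sums below cannot overflow.
template <typename T>
T mulModAddShift(T a, T b, const T& p) {
    T rv = 0;
    a %= p;                       // one full reduction up front, so a < p
    while (b > 0) {
        if ((b & 1) != 0) {
            rv += a;              // rv < p and a < p, so rv < 2 * p
            if (rv >= p) rv -= p; // a single subtraction restores rv < p
        }
        a += a;                   // a < p, so the doubled value is < 2 * p
        if (a >= p) a -= p;
        b >>= 1;
    }
    return rv;
}
```

For example, mulModAddShift<std::uint64_t>(20, 5, 13) reduces 20 to 7 and returns 7 * 5 mod 13 = 9. The loop still runs once per bit of b, so for 2048-bit operands this removes the divisions but not the 2048 iterations; the multiply-then-reduce version or a Montgomery-based routine may still be faster.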