c++ algorithm bit-manipulation modulo integer-overflow

Modulo Multiplication Function: Multiplying two integers under a modulus


I came across this modulo multiplication function in an implementation of the Miller-Rabin primality test. It is supposed to avoid the integer overflow that can occur when calculating (a * b) % m.

I need some help understanding what is going on here. Why does this work, and what is the significance of the literal 0x8000000000000000ULL?

unsigned long long mul_mod(unsigned long long a, unsigned long long b, unsigned long long m) {
    unsigned long long d = 0, mp2 = m >> 1;
    if (a >= m) a %= m;
    if (b >= m) b %= m;
    for (int i = 0; i < 64; i++)
    {
        d = (d > mp2) ? (d << 1) - m : d << 1;
        if (a & 0x8000000000000000ULL)
            d += b;
        if (d >= m) d -= m;
        a <<= 1;
    }
    return d;
}

Solution

  • This code, which currently appears on the Wikipedia page for modular arithmetic, only works reliably for arguments of up to 63 bits; see "Restriction to inputs of 63 bits or fewer" below.

    Overview

    One way to compute an ordinary multiplication a * b is to add left-shifted copies of b -- one for each 1-bit in a. This is similar to how most of us did long multiplication in school, but simplified: Since we only ever need to "multiply" each copy of b by 1 or 0, all we need to do is either add the shifted copy of b (when the corresponding bit of a is 1) or do nothing (when it's 0).
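
As a minimal sketch of the shift-and-add idea described above (no modulus yet), assuming unsigned long long is 64 bits wide as the question's code does. The function name shift_add_mul is hypothetical and only used for illustration:

```cpp
// Ordinary shift-and-add multiplication: for every 1-bit in a, add the copy
// of b shifted to that bit's position. The result wraps modulo 2^64, exactly
// like the built-in a * b on 64-bit unsigned integers.
// Assumption: unsigned long long is 64 bits wide.
unsigned long long shift_add_mul(unsigned long long a, unsigned long long b) {
    unsigned long long total = 0;
    for (int i = 0; i < 64; i++) {
        if ((a >> i) & 1)       // bit i of a is set ...
            total += b << i;    // ... so add b shifted left by i places
    }
    return total;
}
```

For example, shift_add_mul(13, 11) adds 11 << 0, 11 << 2 and 11 << 3, giving 143.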

    This code does something similar. However, to avoid overflow (mostly; see below), instead of shifting each copy of b and then adding it to the total, it adds an unshifted copy of b to the total and relies on later left-shifts of the total to move it into the correct place. You can think of these shifts as "acting on" all the summands added to the total so far. For example, the first loop iteration checks whether the highest bit of a, namely bit 63, is 1 (that's what a & 0x8000000000000000ULL does: the literal has only bit 63 set), and if so adds an unshifted copy of b to the total; by the time the loop completes, the doubling line at the top of the loop body will have shifted the total d left by 1 bit 63 more times, so that copy of b ends up multiplied by 2^63 (modulo m). Because a itself is shifted left by 1 bit at the end of every iteration, the same test examines the next lower bit of a on each subsequent pass.
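
The "add unshifted, shift the total later" ordering can be sketched without any modulus, which may make the loop structure easier to see. This is an illustration only, under the assumption that unsigned long long is 64 bits wide; msb_first_mul is a hypothetical name:

```cpp
// Multiplication that scans a from its highest bit down. Each pass first
// doubles the running total (which shifts every summand added so far),
// then adds an unshifted b if the current top bit of a is set.
// The result wraps modulo 2^64, like the built-in a * b.
unsigned long long msb_first_mul(unsigned long long a, unsigned long long b) {
    unsigned long long total = 0;
    for (int i = 0; i < 64; i++) {
        total <<= 1;                        // acts on all summands so far
        if (a & 0x8000000000000000ULL)      // is bit 63 of a set?
            total += b;
        a <<= 1;                            // move the next lower bit to bit 63
    }
    return total;
}
```

The question's mul_mod has the same shape; it only replaces the plain doubling and the plain addition with versions that stay reduced modulo m.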

    The main advantage of doing it this way is that we are always adding two numbers (namely b and d) that we already know are at most m, with b strictly less than m (d can momentarily equal m after the doubling step, as explained below), so handling the modulo wraparound is cheap: We know that b + d < 2 * m, so to ensure that our total so far remains less than m, it suffices to check whether b + d >= m and, if so, subtract m once. If we were to use the shift-then-add approach instead, we would need a % modulo operation per bit, which is as expensive as division, and usually much more expensive than subtraction.
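
A sketch of that cheap reduction step in isolation, assuming both operands are already reduced and that m is at most 2^63 so the sum cannot overflow 64 bits. The helper name add_mod_small is hypothetical:

```cpp
#include <cassert>

// Adds two already-reduced values modulo m with one comparison and at most
// one subtraction, instead of a % operation.
// Assumptions: x < m, y < m, and m <= 2^63 so that x + y fits in 64 bits.
unsigned long long add_mod_small(unsigned long long x, unsigned long long y,
                                 unsigned long long m) {
    assert(x < m && y < m);
    assert(m <= (1ULL << 63));
    unsigned long long s = x + y;   // s < 2 * m, no overflow under the assumptions
    if (s >= m) s -= m;             // a single subtraction brings s back below m
    return s;
}
```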

    One of the properties of modulo arithmetic is that, whenever we want to perform a sequence of arithmetic operations modulo some number m, performing them all in usual arithmetic and taking the remainder modulo m at the end always yields the same result as taking remainders modulo m for each intermediate result (provided no overflows occur).
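
To make this property concrete, here is a small sketch that computes (a * b + c) mod m in both ways. It assumes the operands are small enough that nothing overflows 64 bits; both function names are hypothetical:

```cpp
#include <cassert>

// Reduce only once, at the very end.
// Assumption: a * b + c does not overflow 64 bits.
unsigned long long reduce_at_end(unsigned long long a, unsigned long long b,
                                 unsigned long long c, unsigned long long m) {
    assert(m != 0);
    return (a * b + c) % m;
}

// Reduce after every intermediate step.
// Assumption: (a % m) * (b % m) does not overflow 64 bits.
unsigned long long reduce_each_step(unsigned long long a, unsigned long long b,
                                    unsigned long long c, unsigned long long m) {
    assert(m != 0);
    unsigned long long t = ((a % m) * (b % m)) % m;
    return (t + c % m) % m;
}
```

With a = 7, b = 9, c = 5 and m = 11, both routes give 2: either 68 mod 11 directly, or 63 mod 11 = 8 followed by 13 mod 11.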

    Code

    At the start of every loop iteration, the invariants d < m and b < m hold.

    The line

    d = (d > mp2) ? (d << 1) - m : d << 1;
    

    is a careful way of shifting the total d left by 1 bit while keeping it in the range 0 .. m (inclusive) and avoiding overflow. Instead of first shifting it and then testing whether the result is m or greater, we test whether it is currently strictly above RoundDown(m/2): if so, after doubling it will be strictly above 2 * RoundDown(m/2) >= m - 1, that is, at least m, and so requires a subtraction of m to get back in range. Note that even though the (d << 1) in (d << 1) - m may overflow and lose the top bit of d, this does no harm: unsigned arithmetic wraps modulo 2^64, so the lost bit does not affect the lowest 64 bits of the subtraction result, which are the only ones we are interested in. (Also note that if m is even and d == m/2 exactly, we wind up with d == m afterward, which is slightly out of range. Changing the test from d > mp2 to d >= mp2 to fix this would break the case where m is odd and d == RoundDown(m/2), so we have to live with it. It doesn't matter, because the if (d >= m) d -= m; line later in the loop body fixes it up.)
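
The doubling line can be pulled out into a helper to try the edge cases by hand. This is a sketch for experimentation, assuming a 64-bit unsigned long long; double_step is a hypothetical name:

```cpp
#include <cassert>

// The doubling step from the loop, isolated.
// Precondition: d < m.
// Result: 2 * d reduced into the range 0 .. m (it equals m only when m is
// even and d == m / 2).
unsigned long long double_step(unsigned long long d, unsigned long long m) {
    assert(d < m);
    unsigned long long mp2 = m >> 1;
    // If d > mp2, then 2 * d >= m, so subtract m. The shift may drop the top
    // bit, but unsigned wraparound makes the difference come out right.
    return (d > mp2) ? (d << 1) - m : d << 1;
}
```

With m = 10, double_step(6, 10) gives 2 and double_step(5, 10) gives 10, the slightly out-of-range case. With m = 0xFFFFFFFFFFFFFFFFULL and d = 0xFFFFFFFFFFFFFFFEULL the shift overflows, yet the result is the correct 0xFFFFFFFFFFFFFFFDULL.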

    Why not simply write d <<= 1; if (d >= m) d -= m; instead? Suppose that, in infinite-precision arithmetic, d << 1 >= m, so we should perform the subtraction, but the high bit of d is on and the remaining 64 bits of d << 1 are less than m. In this case the initial shift loses the high bit, the truncated value compares less than m, and the subtraction is wrongly skipped.
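
A concrete instance of that failure, as a sketch assuming a 64-bit unsigned long long. The name naive_double is hypothetical, and the input values were picked by hand to trigger the problem:

```cpp
// The "obvious" doubling step, which silently loses the top bit of d.
unsigned long long naive_double(unsigned long long d, unsigned long long m) {
    d <<= 1;                // bit 63 of d is shifted out and lost here
    if (d >= m) d -= m;     // compares the truncated value, so it may not fire
    return d;
}

// Example: d = 0x8000000000000000 (2^63), m = 0x8000000000000001 (2^63 + 1).
// Mathematically 2 * d mod m = 2^64 - (2^63 + 1) = 0x7FFFFFFFFFFFFFFF,
// but d << 1 truncates to 0, the test 0 >= m fails, and 0 is returned.
```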

    Restriction to inputs of 63 bits or fewer

    The above edge case can only occur when d's high bit is on, which can only occur when m's high bit is also on (since we maintain the invariant d < m). So it looks like the code is taking pains to work correctly even with very high values of m. Unfortunately, it turns out that it can still overflow elsewhere, resulting in incorrect answers for some inputs that set the top bit. For example, when a = 3, b = 0x7FFFFFFFFFFFFFFFULL and m = 0xFFFFFFFFFFFFFFFFULL, the correct answer should be 0x7FFFFFFFFFFFFFFEULL, but the code will return 0x7FFFFFFFFFFFFFFDULL (an easy way to see the correct answer is to rerun with the values of a and b swapped). Specifically, this behaviour occurs whenever the line d += b overflows and leaves the truncated d less than m, causing a subtraction to be erroneously skipped.
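
The failing input can be reproduced with the function from the question, copied here unchanged apart from the hypothetical name mul_mod_original (again assuming a 64-bit unsigned long long):

```cpp
// The question's function, renamed so it can sit next to other variants.
unsigned long long mul_mod_original(unsigned long long a, unsigned long long b,
                                    unsigned long long m) {
    unsigned long long d = 0, mp2 = m >> 1;
    if (a >= m) a %= m;
    if (b >= m) b %= m;
    for (int i = 0; i < 64; i++)
    {
        d = (d > mp2) ? (d << 1) - m : d << 1;
        if (a & 0x8000000000000000ULL)
            d += b;             // can overflow when m has its top bit set
        if (d >= m) d -= m;     // ... and is then wrongly skipped
        a <<= 1;
    }
    return d;
}

// mul_mod_original(3, 0x7FFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL)
//     returns 0x7FFFFFFFFFFFFFFD (wrong)
// mul_mod_original(0x7FFFFFFFFFFFFFFFULL, 3, 0xFFFFFFFFFFFFFFFFULL)
//     returns 0x7FFFFFFFFFFFFFFE (correct)
```

In the first call, the last iteration computes d += b with d = 0xFFFFFFFFFFFFFFFE and b = 0x7FFFFFFFFFFFFFFF; the sum wraps around to 0x7FFFFFFFFFFFFFFD, which is below m, so no subtraction happens.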

    Provided this behaviour is documented (as it is on the Wikipedia page), this is just a limitation, not a bug.

    Removing the restriction

    If we replace the lines

        if (a & 0x8000000000000000ULL)
            d += b;
        if (d >= m) d -= m;
    

    with

        unsigned long long x = -(a >> 63) & b;
        if (d >= m - x) d -= m;
        d += x;
    

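    the code will work for all inputs, including those with the top bit set. In the cryptic first line, a >> 63 is 0 or 1, and negating that unsigned value gives either 0 or all ones, which then masks b. It is just a conditional-free (and thus usually faster) way of writing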

        unsigned long long x = (a & 0x8000000000000000ULL) ? b : 0;
    

    The test d >= m - x operates on d before it has been modified -- it's like the old d >= m test, but b (when the top bit of a is on) or 0 (otherwise) has been subtracted from both sides. This tests whether d would be m or larger once x is added to it. We know that the RHS m - x never underflows, because the largest x can be is b and we have established that b < m at the top of the function. If the subtraction d -= m wraps below zero, the following d += x wraps it back again, so the final value is exact.
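
Putting the replacement into the original function gives the following sketch. It assumes a 64-bit unsigned long long and a nonzero m; the name mul_mod_fixed is hypothetical:

```cpp
#include <cassert>

// The question's function with the three-line replacement applied.
// Assumptions: unsigned long long is 64 bits wide and m != 0.
unsigned long long mul_mod_fixed(unsigned long long a, unsigned long long b,
                                 unsigned long long m) {
    assert(m != 0);
    unsigned long long d = 0, mp2 = m >> 1;
    if (a >= m) a %= m;
    if (b >= m) b %= m;
    for (int i = 0; i < 64; i++) {
        // Double d, leaving it in 0 .. m.
        d = (d > mp2) ? (d << 1) - m : d << 1;
        // x is b if the top bit of a is set, otherwise 0.
        unsigned long long x = -(a >> 63) & b;
        // Subtract m first if d + x would reach m; this may wrap below zero,
        // and the addition on the next line wraps it back.
        if (d >= m - x) d -= m;
        d += x;
        a <<= 1;
    }
    return d;
}
```

On the input that broke the original, mul_mod_fixed(3, 0x7FFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL) returns 0x7FFFFFFFFFFFFFFE.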