Tags: rust, integer-overflow, integer-arithmetic, modular-arithmetic

How to do arithmetic modulo another number, without overflow?


I'm trying to implement a fast primality test for Rust's u32 and u64 datatypes. As part of it, I need to compute (n*n)%d where n and d are u32 (or u64, respectively).

While the result can easily fit in the datatype, I'm at a loss for how to compute this. As far as I know there is no processor primitive for this.

For u32 we can fake it: cast up to u64 so that the product won't overflow, take the modulus, then cast back down to u32, knowing the remainder fits. However, since I don't have a u128 datatype (as far as I know), this trick won't work for u64.
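For reference, the u32 widening trick can be written as below. As an aside, Rust has had a stable `u128` type since version 1.26, so on current compilers the same trick also covers u64. `mul_mod32` and `mul_mod64_wide` are illustrative names, not from the question; this is a minimal sketch, not a tuned implementation.

```rust
/// (a * b) % m for u32, computed in u64 so the product cannot overflow.
/// Panics if m == 0.
fn mul_mod32(a: u32, b: u32, m: u32) -> u32 {
    // (2^32 - 1)^2 < 2^64, and the remainder is < m, so narrowing is lossless.
    ((a as u64 * b as u64) % m as u64) as u32
}

/// The same trick one size up; requires u128 (stable since Rust 1.26).
/// Panics if m == 0.
fn mul_mod64_wide(a: u64, b: u64, m: u64) -> u64 {
    ((a as u128 * b as u128) % m as u128) as u64
}
```

On common 64-bit targets the `u128` remainder may be lowered to a library routine rather than a single instruction, so it is convenient but not necessarily the fastest option.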

So for u64, the most obvious way I can think of to accomplish this is to somehow compute x*y to get a pair (carry, product) of u64, so we capture the amount of overflow instead of just losing it (or panicking, or whatever).
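The (carry, product) pair described above can be computed with only u64 arithmetic by splitting each operand into 32-bit halves (schoolbook multiplication). A minimal sketch; `mul_wide` is a hypothetical helper name. It only produces the full 128-bit product as (high, low) words: reducing that pair modulo d would still need a 128-by-64 division, which this sketch does not attempt.

```rust
/// Returns the full 128-bit product of x and y as (high, low) u64 words,
/// using only u64 arithmetic (schoolbook multiplication on 32-bit halves).
fn mul_wide(x: u64, y: u64) -> (u64, u64) {
    const MASK: u64 = 0xFFFF_FFFF;
    // Split each operand: x = x1 * 2^32 + x0, y = y1 * 2^32 + y0.
    let (x0, x1) = (x & MASK, x >> 32);
    let (y0, y1) = (y & MASK, y >> 32);

    // Each partial product is at most (2^32 - 1)^2 < 2^64: no overflow.
    let p00 = x0 * y0;
    let p01 = x0 * y1;
    let p10 = x1 * y0;
    let p11 = x1 * y1;

    // Middle column: three terms below 2^32 each, so the sum is below 2^34.
    let mid = (p00 >> 32) + (p01 & MASK) + (p10 & MASK);

    // Low word: low 32 bits of mid on top, low 32 bits of p00 below.
    // (`<<` discards the bits shifted out; it does not panic.)
    let lo = (mid << 32) | (p00 & MASK);
    // High word: everything carried past bit 63. This is the exact high
    // word of a product below 2^128, so it cannot overflow either.
    let hi = p11 + (p01 >> 32) + (p10 >> 32) + (mid >> 32);
    (hi, lo)
}
```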

Is there a way to do this? Or another standard way to solve the problem?


Solution

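  • Richard Rast pointed out that the Wikipedia version works only with 63-bit integers. I extended the code provided by Boiethios to work with the full range of 64-bit unsigned integers. It multiplies bit by bit (double-and-add over the bits of x, from the most significant down), keeping the accumulator reduced modulo m, and handles moduli with the top bit set in a separate branch that uses wrapping and overflowing arithmetic.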

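    /// Computes (x * y) % m for any m > 0 without overflowing u64, using
    /// left-to-right binary multiplication: for each bit of x, from the most
    /// significant down, double the accumulator d and add y if the bit is set,
    /// reducing modulo m at every step. Panics if m == 0.
    fn mul_mod64(mut x: u64, mut y: u64, m: u64) -> u64 {
        let msb = 0x8000_0000_0000_0000;
        let mut d = 0;
        let mp2 = m >> 1;
        x %= m;
        y %= m;
    
        if m & msb == 0 {
            // m < 2^63: every intermediate value is below 2m < 2^64,
            // so plain arithmetic cannot overflow.
            for _ in 0..64 {
                d = if d > mp2 {
                    (d << 1) - m
                } else {
                    d << 1
                };
                if x & msb != 0 {
                    d += y;
                }
                if d >= m {
                    d -= m;
                }
                x <<= 1;
            }
            d
        } else {
            // m >= 2^63: 2*d and d + y may exceed u64::MAX, so wrapping
            // arithmetic is used; the mathematically correct value always
            // lies in [0, m], so the wrapped result equals it.
            for _ in 0..64 {
                d = if d > mp2 {
                    d.wrapping_shl(1).wrapping_sub(m)
                } else {
                    // d << 1 can equal m here (m even, d == m / 2); that value
                    // is congruent to 0 and is reduced after the loop
                    d << 1
                };
                if x & msb != 0 {
                    let (mut d1, overflow) = d.overflowing_add(y);
                    if overflow {
                        // the true sum is >= 2^64 but < 2m, so subtracting m
                        // (wrapping) gives the exact value, which is below m
                        d1 = d1.wrapping_sub(m);
                    }
                    d = if d1 >= m { d1 - m } else { d1 };
                }
                x <<= 1;
            }
            if d >= m { d - m } else { d }
        }
    }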
    
    #[test]
    fn test_mul_mod64() {
        let half = 1 << 16; // 2^16, used as a small power-of-two modulus
        let max = u64::MAX;
    
        assert_eq!(mul_mod64(0, 0, 2), 0);
        assert_eq!(mul_mod64(1, 0, 2), 0);
        assert_eq!(mul_mod64(0, 1, 2), 0);
        assert_eq!(mul_mod64(1, 1, 2), 1);
        assert_eq!(mul_mod64(42, 1, 2), 0);
        assert_eq!(mul_mod64(1, 42, 2), 0);
        assert_eq!(mul_mod64(42, 42, 2), 0);
        assert_eq!(mul_mod64(42, 42, 42), 0);
        assert_eq!(mul_mod64(42, 42, 41), 1);
        assert_eq!(mul_mod64(1239876, 2948635, 234897), 163320);
    
        assert_eq!(mul_mod64(1239876, 2948635, half), 18476);
        assert_eq!(mul_mod64(half, half, half), 0);
        assert_eq!(mul_mod64(half+1, half+1, half), 1);
    
        assert_eq!(mul_mod64(max, max, max), 0);
        assert_eq!(mul_mod64(1239876, 2948635, max), 3655941769260);
        assert_eq!(mul_mod64(1239876, max, max), 0);
        assert_eq!(mul_mod64(1239876, max-1, max), max-1239876);
        assert_eq!(mul_mod64(max, 2948635, max), 0);
        assert_eq!(mul_mod64(max-1, 2948635, max), max-2948635);
        assert_eq!(mul_mod64(max-1, max-1, max), 1);
        assert_eq!(mul_mod64(2, max/2, max-1), 0);
    }
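An alternative formulation of the same shift-and-add idea, not the code above: process the bits of the second factor from the least significant upward, and do every modular addition through a helper that never forms a sum above u64::MAX. Because `a + b >= m` holds exactly when `a >= m - b`, no separate branch for moduli with the top bit set is needed. `add_mod`, `mul_mod`, and `pow_mod` are hypothetical names; `pow_mod` (square-and-multiply) is included as a sketch because modular exponentiation is what a Fermat or Miller-Rabin style primality test typically uses the multiplication for.

```rust
/// (a + b) % m for a, b < m, without ever forming a value above u64::MAX.
fn add_mod(a: u64, b: u64, m: u64) -> u64 {
    // a + b >= m  <=>  a >= m - b; m - b cannot underflow because b < m.
    if a >= m - b { a - (m - b) } else { a + b }
}

/// (a * b) % m for the full u64 range, by right-to-left double-and-add.
fn mul_mod(mut a: u64, mut b: u64, m: u64) -> u64 {
    assert!(m > 0, "modulus must be nonzero");
    a %= m;
    let mut result = 0;
    // Invariant: result + a * b is congruent to the original product, with a, result < m.
    while b > 0 {
        if b & 1 == 1 {
            result = add_mod(result, a, m);
        }
        a = add_mod(a, a, m); // a = 2a mod m
        b >>= 1;
    }
    result
}

/// base^exp % m by square-and-multiply on top of mul_mod.
fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    let mut result = 1 % m; // 0 when m == 1
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = mul_mod(result, base, m);
        }
        base = mul_mod(base, base, m);
        exp >>= 1;
    }
    result
}
```

Both this loop and the answer's version take up to 64 iterations per multiplication, trading speed for working without a wider integer type.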