Tags: c++, algorithm, cryptography, rsa, montgomery-multiplication

Bizarre identical incorrect results across different MWR2MM algorithms for RSA Montgomery multiplication


Background

I'm trying to implement RSA 2048 in hardware (Xilinx Zynq FPGA) using a variety of different Montgomery methods. I am implementing the algorithm using Xilinx HLS (essentially C++ code that is synthesized into hardware).

Note: For the sake of this post, treat it just like a standard C++ implementation, except that I can have variables that act like bit vectors up to 4096 bits wide, and access individual bits using foo[bit] or foo.range(7,0) syntax. I haven't parallelized it yet, so it should behave just like standard C++ code. Please don't stop reading just because I said FPGA and HLS; just treat this as C++ code.

I've been able to get a working prototype that uses the standard square-and-multiply algorithm for modular exponentiation and the standard radix-2 MM algorithm for the modular multiplication. However, it takes up too much space on the FPGA, so I need to use less resource-heavy algorithms.
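For comparison, the working baseline mentioned above (square-and-multiply on top of bit-serial radix-2 Montgomery multiplication) can be sketched in plain C++ on small integers. This is only an illustrative sketch, not the HLS code: it assumes uint64_t in place of ap_uint and an odd modulus M < 2^n with n ≤ 62, and the helper names mont_mul and mont_pow are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Bit-serial radix-2 Montgomery multiplication on plain integers.
// Returns X*Y*2^(-n) mod M, fully reduced to [0, M).
// Assumes: M odd, M < 2^n, X < M, Y < M, n <= 62 (so S + Y + M fits in 64 bits).
uint64_t mont_mul(uint64_t X, uint64_t Y, uint64_t M, unsigned n) {
    assert(M & 1);
    uint64_t S = 0;
    for (unsigned i = 0; i < n; ++i) {
        if ((X >> i) & 1) S += Y;  // S += x_i * Y
        if (S & 1) S += M;         // q_i = S mod 2 makes S even
        S >>= 1;                   // exact division by 2
    }
    if (S >= M) S -= M;            // S was in [0, 2M): final subtraction
    return S;
}

// Left-to-right square-and-multiply carried out in the Montgomery domain.
// Assumes M > 1 in addition to the mont_mul assumptions.
uint64_t mont_pow(uint64_t base, uint64_t exp, uint64_t M, unsigned n) {
    uint64_t R2 = 1;                          // will hold 2^(2n) mod M
    for (unsigned k = 0; k < 2 * n; ++k) {
        R2 <<= 1;
        if (R2 >= M) R2 -= M;
    }
    const uint64_t xm = mont_mul(base % M, R2, M, n);  // base * 2^n mod M
    uint64_t acc = mont_mul(1, R2, M, n);              // 1 * 2^n mod M
    for (int b = 63; b >= 0; --b) {
        acc = mont_mul(acc, acc, M, n);                     // square
        if ((exp >> b) & 1) acc = mont_mul(acc, xm, M, n);  // multiply
    }
    return mont_mul(acc, 1, M, n);            // convert back out of the domain
}
```

Note that mont_mul performs the final conditional subtraction itself, so every value fed back into it stays below M.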

To save space, I'm trying to implement the Tenca-Koç Scalable Multiple Word Radix 2 Montgomery Multiplication (MWR2MM) proposed here. I've been struggling with it for a month to no avail. However, my struggles have turned up an interesting problem that I cannot figure out.

The Problem

My issue is that MWR2MM does not return the correct answer when performing the Montgomery multiplication. However, I am beginning to think that it is not a coding error, but rather that I am misinterpreting something critical about how the algorithm is used.

There are multiple variations of the MWR2MM algorithm, with quite different implementations, and I've tried to implement many of them. I currently have 4 different implementations of MWR2MM coded up, all based on modifications of the algorithm put forth in a number of papers. What makes me think that my implementations are actually correct is that all of these versions of the algorithm return the same INCORRECT answer! I don't think this is a coincidence, but I also don't think the published algorithms are wrong... Therefore, I suspect that something more subtle is going on, and that my implementations are correct.

Example 1

For example, take the original MWR2MM proposed in Tenca and Koç's paper, which we refer to as MWR2MM_CSA because the algorithm's additions all use carry-save adders (CSA) when implemented in hardware.

  • S is the partial sum
  • M is the modulus
  • Y is the multiplicand
  • X is the multiplier and x_i (subscript) is a single bit (e.g. X = (x_n,...,x_1,x_0))
  • Superscripts denote the words of a vector (e.g. M = (0,M^{e-1},...,M^1,M^0))
  • (A,B) is the concatenation of two bit vectors
  • m is the operand width
  • w is the width of the chosen words
  • e is the number of w-bit words required to hold the vectors (e.g. e = ceil((m+1)/w))

[Image: the MWR2MM_CSA algorithm from the Tenca-Koç paper]

My implementation of this algorithm uses the following parameters:

  • MWR2MM_m = 2048 (operand size, m from above)
  • MWR2MM_w = 8 (word size, w from above)
  • MWR2MM_e = ceil( (m+1)/w ) = 257 (number of w-bit words per operand, e from above)
  • ap_uint<NUM_BITS> is how you declare a bit vector in HLS
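The word count e can be checked at compile time. A minimal sketch, where num_words is a hypothetical helper rather than anything provided by HLS; the extra bit in m + 1 presumably leaves room for intermediate sums that can reach 2M - 1:

```cpp
#include <cstddef>

// Number of w-bit words needed to hold an (m+1)-bit value:
// e = ceil((m + 1) / w), computed with integer ceiling division.
constexpr std::size_t num_words(std::size_t m, std::size_t w) {
    return (m + 1 + w - 1) / w;
}

static_assert(num_words(2048, 8) == 257, "RSA-2048 with 8-bit words");
static_assert(num_words(2048, 16) == 129, "RSA-2048 with 16-bit words");
static_assert(num_words(4, 2) == 3, "4-bit operands with 2-bit words");
```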

My code:

void mwr2mm_csa( ap_uint<MWR2MM_m> X,
                 ap_uint<MWR2MM_w> Y[MWR2MM_e+1],
                 ap_uint<MWR2MM_w> M[MWR2MM_e+1],
                 ap_uint<MWR2MM_m> *out)
{
    // Declare and zero partial sum S (e+1 words, like Y and M, since j runs up to MWR2MM_e)
    ap_uint<MWR2MM_w> S[MWR2MM_e+1];
    for (int i=0; i<=MWR2MM_e; i++)
        S[i] = 0;

    // Two Carry bits
    ap_uint<1> Ca=0, Cb=0;

    for (int i=0; i<MWR2MM_m; i++)
    {
        (Ca,S[0]) = X[i]*Y[0] + S[0]; // this is how HLS concatenates vectors, just like in the paper!
        if (S[0][0] == 1) // if the 0th bit of the 0th word is 1
        {
            (Cb,S[0]) = S[0] + M[0];
            for (int j=1; j<=MWR2MM_e; j++)
            {   
                (Ca, S[j]) = Ca + X[i]*Y[j] + S[j];
                (Cb, S[j]) = Cb + M[j] + S[j];
                S[j-1] = ( S[j][0], S[j-1].range(MWR2MM_w-1,1) );
            }
        }
        else
        {
            for (int j=1; j<=MWR2MM_e; j++)
            {
                (Ca, S[j]) = Ca + X[i]*Y[j] + S[j];
                S[j-1] = ( S[j][0], S[j-1].range(MWR2MM_w-1,1) );
            }
        }
    }

    // copy the result to the output pointer
    for (int i=0; i<MWR2MM_e-1; i++)
        out->range(MWR2MM_w*i+(MWR2MM_w-1), MWR2MM_w*i) = S[i].to_uchar();
}

Now, my understanding is the following (quoting from the paper above):

the Montgomery Multiplication (MM) algorithm on two integers X and Y, with required parameters for n bits of precision, will result in the number MM(X,Y,M) = XY(2^-n) (mod M), where r = 2^n and M is an integer in the range (2^(n-1), 2^n) such that gcd(r,M) = 1. Since r = 2^n, it is sufficient that the modulus M be an odd integer.
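To check results in software, the definition above can be evaluated directly without Montgomery's algorithm: reduce X·Y mod M, then multiply by 2^(-1) mod M, n times. Halving modulo M is only possible because M is odd, which is why an odd modulus is sufficient. A minimal sketch, assuming M < 2^32 so that X*Y fits in 64 bits; mm_reference is a hypothetical name:

```cpp
#include <cassert>
#include <cstdint>

// Reference MM(X, Y, M) = X*Y*2^(-n) mod M, computed directly (no Montgomery
// loop): reduce X*Y mod M, then multiply by 2^(-1) mod M, n times.
// Assumes M odd and M < 2^32 so that the product fits in 64 bits.
uint64_t mm_reference(uint64_t X, uint64_t Y, uint64_t M, unsigned n) {
    assert(M % 2 == 1);
    uint64_t v = (X % M) * (Y % M) % M;
    for (unsigned k = 0; k < n; ++k)
        v = (v % 2 == 0) ? v / 2 : (v + M) / 2;  // v * 2^(-1) mod M
    return v;  // always fully reduced: 0 <= v < M
}
```

Unlike the hardware algorithms, this reference always returns a fully reduced value in [0, M).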

Therefore, we should expect the following results (verified with a software library):

X = 0xABA5E025B607AA14F7F1B8CC88D6EC01C2D17C536508E7FA10114C9437D9616C9E1C689A4FC54744FA7DFE66D6C2FCF86E332BFD6195C13FE9E331148013987A947D9556A27A326A36C84FB38BFEFA0A0FFA2E121600A4B6AA4F9AD2F43FB1D5D3EB5EABA13D3B382FED0677DF30A089869E4E93943E913D0DC099AA320B8D8325B2FC5A5718B19254775917ED48A34E86324ADBC8549228B5C7BEEEFA86D27A44CEB204BE6F315B138A52EC714888C8A699F6000D1CD5AB9BF261373A5F14DA1F568BE70A0C97C2C3EFF0F73F7EBD47B521184DC3CA932C91022BF86DD029D21C660C7C6440D3A3AE799097642F0507DFAECAC11C2BD6941CBC66CEDEEAB744
Y = 0xD091BE9D9A4E98A172BD721C4BC50AC3F47DAA31522DB869EB6F98197E63535636C8A6F0BA2FD4C154C762738FBC7B38BDD441C5B9A43B347C5B65CFDEF4DCD355E5E6F538EFBB1CC161693FA2171B639A2967BEA0E3F5E429D991FE1F4DE802D2A1D600702E7D517B82BFFE393E090A41F57E966A394D34297842552E15550B387E0E485D81C8CCCAAD488B2C07A1E83193CE757FE00F3252E4BD670668B1728D73830F7AE7D1A4C02E7AFD913B3F011782422F6DE4ED0EF913A3A261176A7D922E65428AE7AAA2497BB75BFC52084EF9F74190D0D24D581EB0B3DAC6B5E44596881200B2CE5D0FB2831D65F036D8E30D5F42BECAB3A956D277E3510DF8CBA9
M = 0xD27BF9F01E2A901DB957879F45F697330D21A21095DA4FA7D3AAB75454A8E9F0F4EA531ECE34F0C3BA9E02EB27D8F0DBE78EEDE4AC84061BEEF162D00B55C0DD772D28F23E994899AA19B9BEA7B12A8027A32A92190A3630E249544675488121565A23548FCD36F5382EEB993DB9CE3F526F20AB355E82D963D59541BC1161E211A03E3B372560840C57E12BD2F40EAC5FFCEC01B3F07C378C0A60B74BEF7B572764C88A4F98B61FA8CCD905AFAE779E6193378304D8EB17695CE71A173AC3DE11271753C48DB58546E5AF9917C1CEBBA5BB1AF3FCE3DF9516C0C95C9BC14BB65D1C53078C06C81AC0F3ED0D8634260E47BF780CF4F4996084DF732935194417
MM(X,Y,M) = 0x444682CC199679928F5971191ACCB8EAA5C76CF743E54FC28FD8DCFF57BD198677A26A5C1A6254810A91049FA85CBE3EDDFDCDF12ED3FBB204DE249C389CDEE3FA6DB65441AFE03F1148660EA0E756E038891CEF098F2A009FB443685202FAC40D8FE7B82A1F643020EA31F5A8F4B253AD2F30028C59F1E2DCF3902BBC48E73ECA7BDC22BB92E8A70BC535584BF644CAF24EF39A1899F18C05937446AACC5C64762AFAD2B73EEDF3AA96C9A4CFF836A551A26AED46279328EDD4B9BBBC182B9E408640D058926882B3A0FAA043F726EF96E07B7960D586E2648534EB15C23FE152D0D088F1742E023715E3ABAEC8128B51CC86E8BC207D69F1E6BA7067D44429

But instead, my algorithm returns

MWR2MM_csa(X,Y,M) = 0x16C27CBC37C109B048B0F8B860C3501DB2E90F07D9BF9F6A63839453AC6603776C8CBD7AE8974544C52F078AD035AF1AC58CBBD5DB5801CDF3CF876C43F29FC1719ADF46804928D8BB621FCD48988160602C47812299603181FD97AEC74B7BE563EA0B0CB9EC9B2559191D8EE6AE8092FF9E50ADC1B874BC40C9256D785A4920DC1C1A5DF2B8492B181D16841EEA5377524BDF9BCC8A6DC3919DD4FDF6BBD7BB9D8FC35D06D7A4135363A2AA7FA6AE43B335A2704B007E405731A0D5D352EF7C51AD58241D201E07FA86AA395BB8F5AB3C9B966D5DB966777B45FE47B1838B97AFED23907D7AF61CF809D0B934FC3899998BFEF5B11516CA76C62D999CED8840

Example 2

OK, so maybe that implementation was wrong. Let's try another modified version, the MWR2MM_CPA algorithm (named for the carry-propagate adders used in hardware):

[Image: the MWR2MM_CPA algorithm]

And my implementation of MWR2MM_CPA:

void mwr2mm_cpa(rsaSize_t X, rsaSize_t Yin, rsaSize_t Min, rsaSize_t* out)
{
    // extend operands to 2 extra words longer
    ap_uint<MWR2MM_m+2*MWR2MM_w> Y = Yin;
    ap_uint<MWR2MM_m+2*MWR2MM_w> M = Min;
    ap_uint<MWR2MM_m+2*MWR2MM_w> S = 0;

    ap_uint<2> C = 0;
    bit_t qi = 0;

    // unlike the previous example, we store the concatenations in a temporary variable
    ap_uint<10> temp_concat=0;

    for (int i=0; i<MWR2MM_m; i++)
    {
        qi = (X[i]*Y[0]) xor S[0];

        // C gets top two bits of temp_concat, j'th word of S gets bottom 8 bits of temp_concat
        temp_concat = X[i]*Y.range(MWR2MM_w-1,0) + qi*M.range(MWR2MM_w-1,0) + S.range(MWR2MM_w-1,0);
        C = temp_concat.range(9,8);
        S.range(MWR2MM_w-1,0) = temp_concat.range(7,0);

        for (int j=1; j<=MWR2MM_e; j++)
        {
            temp_concat = C + X[i]*Y.range(MWR2MM_w*j+(MWR2MM_w-1), MWR2MM_w*j) + qi*M.range(MWR2MM_w*j+(MWR2MM_w-1), MWR2MM_w*j) + S.range(MWR2MM_w*j+(MWR2MM_w-1), MWR2MM_w*j);
            C = temp_concat.range(9,8);
            S.range(MWR2MM_w*j+(MWR2MM_w-1), MWR2MM_w*j) = temp_concat.range(7,0);

            // shift right by one bit: the LSB of word j becomes the MSB of word j-1
            S.range(MWR2MM_w*(j-1)+(MWR2MM_w-1), MWR2MM_w*(j-1)) = (S.bit(MWR2MM_w*j), S.range( MWR2MM_w*(j-1)+(MWR2MM_w-1), MWR2MM_w*(j-1)+1));
        }
        S.range(S.length()-1, S.length()-MWR2MM_w) = 0;
        C=0;
    }

    *out = S;

}

When run with the same X, Y and M, this too returns exactly the same incorrect result as MWR2MM_CSA, despite the different bit-level operations.

MWR2MM_cpa(X,Y,M) = 0x16C27CBC37C109B048B0F8B860C3501DB2E90F07D9BF9F6A63839453AC6603776C8CBD7AE8974544C52F078AD035AF1AC58CBBD5DB5801CDF3CF876C43F29FC1719ADF46804928D8BB621FCD48988160602C47812299603181FD97AEC74B7BE563EA0B0CB9EC9B2559191D8EE6AE8092FF9E50ADC1B874BC40C9256D785A4920DC1C1A5DF2B8492B181D16841EEA5377524BDF9BCC8A6DC3919DD4FDF6BBD7BB9D8FC35D06D7A4135363A2AA7FA6AE43B335A2704B007E405731A0D5D352EF7C51AD58241D201E07FA86AA395BB8F5AB3C9B966D5DB966777B45FE47B1838B97AFED23907D7AF61CF809D0B934FC3899998BFEF5B11516CA76C62D999CED8840

For brevity, I will spare you the other two algorithms, which also return the same incorrect result. I should note that both of these algorithms work correctly when used with a 4-bit operand size and 2-bit word size. However, every other operand-size/word-size combination gives an incorrect result, and that incorrect result is the same for all four bit-level implementations.

I cannot for the life of me figure out why all four algorithms return the same incorrect result. My code in the first example is literally a word-for-word transcription of the algorithm in the Tenca-Koç paper!

Am I wrong to assume that the MWR2MM algorithm should return the same result (in the Montgomery domain) as the standard radix-2 MM algorithm? They have the same radix, so the results should be identical regardless of word size. Shouldn't I be able to interchange these algorithms?
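One way to test this assumption is to run a word-serial MWR2MM and the bit-serial radix-2 MM side by side on small operands, with the final subtraction left out of both. This is a plain C++ model, not the HLS code: it stores words in std::vector<uint32_t> instead of ap_uint arrays, assumes M is odd, M < 2^m with m ≤ 60, X, Y < M and 1 ≤ w ≤ 16, and the helper names (to_words, from_words, mwr2mm, mm_radix2) are hypothetical. Both functions pick the same q_i (the parity of S + x_i·Y) at every step, so they should produce identical values for any word size:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Words = std::vector<uint32_t>;  // little-endian: word 0 is least significant

// Split v into `count` words of w bits each.
Words to_words(uint64_t v, unsigned w, unsigned count) {
    Words out(count);
    for (unsigned j = 0; j < count; ++j) {
        out[j] = static_cast<uint32_t>(v & ((1ull << w) - 1));
        v >>= w;
    }
    return out;
}

// Reassemble w-bit words into a single integer.
uint64_t from_words(const Words& s, unsigned w) {
    uint64_t v = 0;
    for (std::size_t j = s.size(); j-- > 0;)
        v = (v << w) | s[j];
    return v;
}

// Word-serial MWR2MM (carry-propagate form). Returns S in [0, 2M) with
// S = X*Y*2^(-m) mod M, WITHOUT the final conditional subtraction.
uint64_t mwr2mm(uint64_t X, uint64_t Yv, uint64_t Mv, unsigned m, unsigned w) {
    assert((Mv & 1) && w >= 1 && w <= 16);
    const unsigned e = (m + 1 + w - 1) / w;  // ceil((m+1)/w)
    const uint32_t mask = (1u << w) - 1;
    const Words Y = to_words(Yv, w, e + 1), M = to_words(Mv, w, e + 1);
    Words S(e + 1, 0);
    for (unsigned i = 0; i < m; ++i) {
        const uint32_t xi = (X >> i) & 1;
        const uint32_t q = ((xi & Y[0]) ^ S[0]) & 1;  // parity of x_i*Y + S
        uint32_t t = xi * Y[0] + q * M[0] + S[0];
        uint32_t C = t >> w;                          // carry into word 1
        S[0] = t & mask;
        for (unsigned j = 1; j <= e; ++j) {
            t = C + xi * Y[j] + q * M[j] + S[j];
            C = t >> w;
            S[j] = t & mask;
            // shift right by one bit: LSB of word j becomes MSB of word j-1
            S[j - 1] = ((S[j] & 1) << (w - 1)) | (S[j - 1] >> 1);
        }
        S[e] = (C << (w - 1)) | (S[e] >> 1);  // C is 0 under the assumptions
    }
    return from_words(S, w);
}

// Bit-serial radix-2 MM on plain integers, also WITHOUT the final subtraction.
uint64_t mm_radix2(uint64_t X, uint64_t Y, uint64_t M, unsigned m) {
    uint64_t S = 0;
    for (unsigned i = 0; i < m; ++i) {
        S += ((X >> i) & 1) * Y;  // S += x_i * Y
        if (S & 1) S += M;        // q_i = S mod 2
        S >>= 1;
    }
    return S;
}
```

Under these assumptions, if the two agree for every w you try, the word size is not what changes the result, and any remaining difference from a fully reduced software reference has to come from the missing final subtraction.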

Sorry for the lengthy post, but I want to be precise and coherent in explaining the issue. I am not asking for help debugging my code; rather, I am trying to figure out whether I am misunderstanding an underlying feature of the Montgomery multiplication algorithms. I'm also curious why different implementations give the same WRONG result.

Thanks!


Solution

  • The issue is that your algorithm actually returns:

    0x116c27cbc37...
      ^
    

    which is greater than M (the leading 1 is bit 2048, which is dropped when the result is copied into your 2048-bit output). If you subtract M from this, you get the expected answer.

    Both algorithms return a value S in the range 0 ≤ S < 2M, so if the result is greater than or equal to M, you need a final subtraction stage.

    In other words, if you test your algorithm with randomly chosen X and Y, you should find that it gives the correct answer half of the time.

    From section 2 of the paper:

    Thus only one conditional subtraction is necessary to bring S[n] to the required range 0 ≤ S[n] < M. This subtraction will be omitted in the subsequent discussion since it is independent of the specific algorithm and architecture and can be treated as a part of post processing.
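The post-processing subtraction quoted above can reuse the same w-bit datapath: subtract M word by word with a borrow, and keep the original S if a borrow comes out of the top word. A minimal sketch in plain C++, assuming little-endian words (word 0 least significant, as in the S[] arrays in the question) and w < 32; final_subtract is a hypothetical name:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Final conditional subtraction, done word-serially on little-endian
// w-bit words (w < 32): returns S - M if S >= M, otherwise S.
// S and M must have the same number of words; S may use one more bit than M.
std::vector<uint32_t> final_subtract(const std::vector<uint32_t>& S,
                                     const std::vector<uint32_t>& M,
                                     unsigned w) {
    const uint64_t mask = (1ull << w) - 1;
    std::vector<uint32_t> D(S.size());
    uint64_t borrow = 0;
    for (std::size_t j = 0; j < S.size(); ++j) {
        // D^j = S^j - M^j - borrow, computed in 64 bits so a negative
        // difference wraps around and sets the top bit.
        const uint64_t diff = uint64_t{S[j]} - M[j] - borrow;
        D[j] = static_cast<uint32_t>(diff & mask);
        borrow = diff >> 63;  // 1 if this word needed a borrow
    }
    // A borrow out of the top word means S < M: keep S unchanged.
    return borrow ? S : D;
}
```

For example, with 8-bit words, subtracting {0xF0, 0xF9, 0x7B, 0xD2, 0x00} from {0xBC, 0x7C, 0xC2, 0x16, 0x01} gives {0xCC, 0x82, 0x46, 0x44, 0x00}, i.e. 0x116C27CBC - 0xD27BF9F0 = 0x444682CC, which mirrors the leading digits of the numbers above. Computing S - M and selecting on the final borrow avoids a separate comparison pass over the words.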