c++ algorithm math recursion number-theory

What exactly happens inside the extended Euclidean algorithm's recursion in C++?


I know what the extended Euclidean algorithm is and why it is used in programming: it is a very useful algorithm for finding the modular inverse of a number. I also know how to implement it in C++; my implementation is below.

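#include <utility>
using namespace std;

typedef pair<int, int> pii;

#define x first
#define y second

// Returns {x, y} such that x*a + y*b == 1, assuming gcd(a, b) == 1.
pii extendedEuclidean(int a, int b)
{
    if(b==0)
        return {a,0};   // here a is the gcd, i.e. 1
    else {
        pii d = extendedEuclidean(b, a%b);
        return {d.y, d.x - (d.y*(a/b))};
    }
}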

Now, if I want to find the modular inverse of a number, say 13, with modulus 1000007, I simply call this function:

pair<int, int> res = extendedEuclidean(13, 1000007);

Then the result is

res.first

My question is: what exactly happens inside this recursion, and why does it produce the correct result?

N.B: here gcd(a, b) must be 1.


Solution

  • The Euclidean algorithm calculates the greatest common divisor of a pair of numbers (a, b) (assuming that a>b). It uses the observation that any common divisor of a and b is also a divisor of a-b. Here is why:

    Let d be the divisor. Then, a can be expressed as a=d*k for an integer k and b=d*l for an integer l. Then, a-b=d*k-d*l=d*(k-l). k-l is again an integer. Thus, d must be a divisor of a-b.

    The algorithm subtracts the smaller number from the larger one as often as possible. This is what a%b computes. E.g. if a = 25 and b = 7, a%b=4 is what you get after subtracting b 3 times from a. After that, the new a is smaller than b, so you swap the two numbers. This is the recursive call: extendedEuclidean(b, a%b);
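    To make the link between repeated subtraction and a%b concrete, here is a minimal sketch of both variants of the plain (non-extended) algorithm. The function names gcdBySubtraction and gcdByModulo are made up for this illustration, and both assume positive inputs.

```cpp
#include <utility>

// Greatest common divisor by repeated subtraction (inputs assumed positive).
int gcdBySubtraction(int a, int b)
{
    while (a != b) {
        if (a > b)
            a -= b;   // any common divisor of a and b also divides a - b
        else
            b -= a;
    }
    return a;
}

// Same result, but a % b performs all subtractions of b in one step.
int gcdByModulo(int a, int b)
{
    while (b != 0) {
        a %= b;           // e.g. 25 % 7 == 4 replaces three subtractions of 7
        std::swap(a, b);  // the remainder is smaller than b, so swap the roles
    }
    return a;
}
```

    The loop in gcdByModulo is the iterative counterpart of the recursive call extendedEuclidean(b, a%b).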

    The extended Euclidean algorithm does a bit more. Additionally, it calculates two numbers x and y, such that gcd(a, b) = x * a + y * b. Here is how it's done:

    In the last iteration, you end up with a'=gcd and b'=0. Thus, you have gcd=a' * 1 + b' * 0, where 1 and 0 are x' and y', respectively. Assume that the values in the previous iteration were a'' and b''. Then we know that a'=b'' and b'=a'' % b''. With this, we find that b'=a''-(a''/b'')*b'' (subtract as often as possible). And we can modify

    gcd = a' * x' + b' * y'
    gcd = b'' * x' + (a''-(a''/b'')*b'') * y'
        = a'' * y' + b'' * (x' - y' * (a''/b''))
    

    Hence, the new x''=y' and y''=x' - y' * (a''/b''). This is your return statement return {d.y, d.x - (d.y*(a/b))};.
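    The derivation above can be checked mechanically. The following is a sketch, not the code from the question: the struct Bezout and the function extendedGcd are hypothetical names, the gcd is returned alongside x and y, and the base case uses x = 1 (the question's code returns a there, which is the same value only when gcd(a, b) is 1). The assert verifies gcd = x * a + y * b at every level of the recursion.

```cpp
#include <cassert>

// Hypothetical helper type: g == x * a + y * b for the inputs (a, b).
struct Bezout {
    long long g, x, y;
};

Bezout extendedGcd(long long a, long long b)
{
    if (b == 0)
        return {a, 1, 0};               // last iteration: gcd = a * 1 + 0 * 0

    Bezout d = extendedGcd(b, a % b);   // d.g == d.x * b + d.y * (a % b)

    // Substitute a % b == a - (a / b) * b and regroup by a and b.
    Bezout r = {d.g, d.y, d.x - d.y * (a / b)};
    assert(r.g == r.x * a + r.y * b);   // the identity holds at this level too
    return r;
}
```

    For a = 25 and b = 7 this yields g = 1, x = 2 and y = -7.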

    An example:

    Let a=25, b=7. The first pass calculates the columns a and b (top to bottom). This accounts for the recursive calls. The second pass calculates the columns x and y (bottom to top). This accounts for the return statements:

     a  | b            |  x   | y                     | means
    ----+--------------+------+-----------------------+---------------------
     25 |  7           |  2   | -1 - 2 * (25/7) = -7  | 1 = 2 * 25 - 7 * 7
      7 |  25 % 7 = 4  | -1   |  1 + 1 * (7/4)  =  2  | 1 = (-1) * 7 + 2 * 4
      4 |  7 % 4  = 3  |  1   |  0 - 1 * (4/3)  = -1  | 1 = 1 * 4 - 1 * 3
      3 |  4 % 3  = 1  |  0   |  1 - 0 * (3/1)  =  1  | 1 = 0 * 3 + 1 * 1
      1 |  3 % 1  = 0  |  1   |  0                    | 1 = 1 * 1 + 0 * 0
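    If you want to reproduce the table yourself, one option is to record a row per recursion level. This is only a sketch: Row, traceExtendedEuclid and printTrace are names invented for this example, and the base case stores x = 1 and y = 0 as in the last table row.

```cpp
#include <cstdio>
#include <vector>

// Hypothetical record of one recursion level: gcd == x * a + y * b.
struct Row {
    int a, b, x, y;
};

// Computes x and y for (a, b) and appends one Row per level.
// Rows are appended as the calls return, so the deepest call comes first.
Row traceExtendedEuclid(int a, int b, std::vector<Row>& rows)
{
    Row r{a, b, 1, 0};                  // base case: gcd = a * 1 + 0 * 0
    if (b != 0) {
        Row d = traceExtendedEuclid(b, a % b, rows);  // first pass: going down
        r.x = d.y;                      // second pass: coming back up
        r.y = d.x - d.y * (a / b);
    }
    rows.push_back(r);
    return r;
}

// Prints the levels top to bottom, i.e. outermost call first.
void printTrace(int a, int b)
{
    std::vector<Row> rows;
    traceExtendedEuclid(a, b, rows);
    for (auto it = rows.rbegin(); it != rows.rend(); ++it)
        std::printf("%3d | %3d | %3d | %3d\n", it->a, it->b, it->x, it->y);
}
```

    Calling printTrace(25, 7) prints the columns a, b, x and y for the five levels of the example.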
    

    So you get 1 = 2 * 25 - 7 * 7 where 2 is the result's .first and -7 is the result's .second. If we are in mod 25, this reduces to:

    1 == 2 * 0 - 7 * 7
    1 == -7 * 7
    

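    Hence, -7 == 18 (mod 25) is the multiplicative inverse of 7 (mod 25). The -7 is result.second; because it is negative, you add the modulus once to get 18. Note that I passed the modulus as the first argument to avoid an unnecessary first iteration. With the argument order from the question (number first, modulus second), the inverse is in result.first.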
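    Since the coefficient can be negative, a modular inverse routine usually normalizes it into the range 0..m-1. Here is a minimal sketch under the assumptions gcd(a, m) == 1 and m > 1. The names extendedEuclid and modInverse are made up for this example, and long long is used instead of int to leave more room in the intermediate products.

```cpp
#include <utility>

// Returns {x, y} with x * a + y * b == gcd(a, b).
std::pair<long long, long long> extendedEuclid(long long a, long long b)
{
    if (b == 0)
        return {1, 0};
    std::pair<long long, long long> d = extendedEuclid(b, a % b);
    return {d.second, d.first - d.second * (a / b)};
}

// Inverse of a modulo m in the range [0, m).
// Assumes gcd(a, m) == 1 and m > 1; the result is meaningless otherwise.
long long modInverse(long long a, long long m)
{
    long long x = extendedEuclid(a, m).first;  // x * a + y * m == 1
    return ((x % m) + m) % m;                  // shift a negative x into [0, m)
}
```

    With the number as the first argument, the coefficient of interest is .first, as in the question. modInverse(7, 25) gives 18, matching the example above.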