
Understanding the StrictMath Java library


I got bored and decided to dive into remaking the square root function without referencing any of the Math.java functions. I have gotten to this point:

package sqrt;
public class SquareRoot {

public static void main(String[] args) {
    System.out.println(sqrtOf(8));

}

public static double sqrtOf(double n){
    double x = log(n,2);
    return powerOf(2, x/2);
}

public static double log(double n, double base)
{
    return (Math.log(n)/Math.log(base));
}

public static double powerOf(double x, double y) {
    return powerOf(e(),y *  log(x, e()));
}   

public static int factorial(int n){
    if(n <= 1){
        return 1;
    }else{
        return n * factorial((n-1));
    }
}

public static double e(){
    return 1/factorial(1);
}
public static double e(int precision){
    return 1/factorial(precision);
}

}

As you can see, my powerOf() function ends up calling itself infinitely. I could replace that call with Math.exp(y * log(x, e())), so I dived into the Math source code to see how it handles the problem, which turned into a goose chase.

public static double exp(double a) {
     return StrictMath.exp(a); // default impl. delegates to StrictMath
 }

which leads to:

public static double exp(double x)
{
 if (x != x)
   return x;
 if (x > EXP_LIMIT_H)
   return Double.POSITIVE_INFINITY;
 if (x < EXP_LIMIT_L)
   return 0;

 // Argument reduction.
 double hi;
 double lo;
 int k;
 double t = abs(x);
 if (t > 0.5 * LN2)
   {
     if (t < 1.5 * LN2)
       {
         hi = t - LN2_H;
         lo = LN2_L;
         k = 1;
       }
     else
       {
         k = (int) (INV_LN2 * t + 0.5);
         hi = t - k * LN2_H;
         lo = k * LN2_L;
       }
     if (x < 0)
       {
         hi = -hi;
         lo = -lo;
         k = -k;
       }
     x = hi - lo;
   }
 else if (t < 1 / TWO_28)
   return 1;
 else
   lo = hi = k = 0;

// Now x is in primary range.
 t = x * x;
 double c = x - t * (P1 + t * (P2 + t * (P3 + t * (P4 + t * P5))));
 if (k == 0)
   return 1 - (x * c / (c - 2) - x);
 double y = 1 - (lo - x * c / (2 - c) - hi);
 return scale(y, k);

}

Values that are referenced:

 LN2 = 0.6931471805599453, // Long bits 0x3fe62e42fefa39efL.
 LN2_H = 0.6931471803691238, // Long bits 0x3fe62e42fee00000L.
 LN2_L = 1.9082149292705877e-10, // Long bits 0x3dea39ef35793c76L.
 INV_LN2 = 1.4426950408889634, // Long bits 0x3ff71547652b82feL.
 INV_LN2_H = 1.4426950216293335, // Long bits 0x3ff7154760000000L.
 INV_LN2_L = 1.9259629911266175e-8; // Long bits 0x3e54ae0bf85ddf44L.
 P1 = 0.16666666666666602, // Long bits 0x3fc555555555553eL.
 P2 = -2.7777777777015593e-3, // Long bits 0xbf66c16c16bebd93L.
 P3 = 6.613756321437934e-5, // Long bits 0x3f11566aaf25de2cL.
 P4 = -1.6533902205465252e-6, // Long bits 0xbebbbd41c5d26bf1L.
 P5 = 4.1381367970572385e-8, // Long bits 0x3e66376972bea4d0L.
 TWO_28 = 0x10000000, // Long bits 0x41b0000000000000L

Here is where I'm starting to get lost. I can at least assume that from this point on the answer is being approximated. I then find myself here:

private static double scale(double x, int n)
{
  if (Configuration.DEBUG && abs(n) >= 2048)
    throw new InternalError("Assertion failure");
  if (x == 0 || x == Double.NEGATIVE_INFINITY
      || ! (x < Double.POSITIVE_INFINITY) || n == 0)
    return x;
  long bits = Double.doubleToLongBits(x);
  int exp = (int) (bits >> 52) & 0x7ff;
  if (exp == 0) // Subnormal x.
    {
      x *= TWO_54;
      exp = ((int) (Double.doubleToLongBits(x) >> 52) & 0x7ff) - 54;
    }
  exp += n;
  if (exp > 0x7fe) // Overflow.
    return Double.POSITIVE_INFINITY * x;
  if (exp > 0) // Normal.
    return Double.longBitsToDouble((bits & 0x800fffffffffffffL)
                                   | ((long) exp << 52));
  if (exp <= -54)
    return 0 * x; // Underflow.
  exp += 54; // Subnormal result.
  x = Double.longBitsToDouble((bits & 0x800fffffffffffffL)
                              | ((long) exp << 52));
  return x * (1 / TWO_54);
}


 TWO_54 = 0x40000000000000L

While I would say I understand math and programming well, I have hit a point where I am facing a Frankenstein's-monster mix of the two. I noticed the switch to bit-level operations (which I have little to no experience with), and I was hoping someone could explain the processes that are occurring "under the hood", so to speak. Specifically, I got lost from "Now x is in primary range" in the exp() method onwards, and I don't know what the referenced values really represent. I'm asking for help understanding not only the methods themselves, but also how they arrive at the answer. Feel free to go as in depth as needed.

edit: if someone could create the tag "strictMath", that would be great. I believe the size of the library, and the fact that the Math library derives from it, justifies the tag's existence.


Solution

  • To the exponential function:

    What happens is that

    exp(x) = 2^k * exp(x-k*log(2))
    

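    is exploited for positive x; negative x is handled by flipping the signs of hi, lo and k afterwards. For large x the reduction x-k*log(2) would introduce cancellation errors, so log(2) is split into a high part LN2_H and a low part LN2_L that are subtracted separately to get more consistent results.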
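
    The reduction can be tried in isolation. The following is only a sketch: the class and method names are made up for illustration, Math.exp stands in for the polynomial part of the real code, Math.scalb stands in for scale(), and NaN, overflow and underflow are not handled.

```java
final class ExpReductionSketch {
    // Constants copied from the StrictMath listing in the question.
    static final double LN2_H = 0.6931471803691238;
    static final double LN2_L = 1.9082149292705877e-10;
    static final double INV_LN2 = 1.4426950408889634;

    // Hypothetical helper: exp(x) = 2^k * exp(r) with r = x - k*log(2).
    static double expViaReduction(double x) {
        // k is the integer nearest to x / log(2), so |r| is at most about 0.5*log(2).
        int k = (int) Math.rint(INV_LN2 * x);
        // Subtract k*log(2) in two pieces. LN2_H has trailing zero bits
        // (see its long bits in the listing), so k * LN2_H stays exact
        // for the moderate values of k that occur here.
        double hi = x - k * LN2_H;
        double lo = k * LN2_L;
        double r = hi - lo;
        // Stand-ins: Math.exp for the rational approximation, Math.scalb for scale().
        return Math.scalb(Math.exp(r), k);
    }

    public static void main(String[] args) {
        for (double x : new double[] {10.0, -3.7, 0.1}) {
            System.out.println(x + ": " + expViaReduction(x) + " vs " + Math.exp(x));
        }
    }
}
```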

    On the reduced x, which satisfies |x| <= 0.5*log(2), a rational approximation with minimized maximal error over that interval is used, see Padé approximations and similar. This is based on the symmetric formula

    exp(x) = exp(x/2)/exp(-x/2) = (c(x²)+x)/(c(x²)-x)
    

    (note that the c in the code is x-(c(x²)-2), so that 2-c in the code equals c(x²)-x here). When using Taylor series, approximations for c(x*x)=x*coth(x/2) are based on

    c(u)=2 + 1/6*u - 1/360*u^2 + 1/15120*u^3 - 1/604800*u^4 + 1/23950080*u^5 - 691/653837184000*u^6
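
    To see how well this formula does on the reduced range, here is a small sketch that uses the Taylor coefficients above, truncated after the u^5 term. The names are illustrative only, and StrictMath itself uses the tuned constants P1..P5 rather than these exact fractions.

```java
final class ExpRationalSketch {
    // c(u) = 2 + u/6 - u^2/360 + ... with u = x*x, truncated after u^5.
    static double c(double u) {
        return 2 + u * (1.0 / 6 + u * (-1.0 / 360 + u * (1.0 / 15120
                + u * (-1.0 / 604800 + u * (1.0 / 23950080)))));
    }

    // exp(x) ~ (c(x*x) + x) / (c(x*x) - x), intended only for |x| <= 0.5*log(2).
    static double expSmall(double x) {
        double r = c(x * x);
        return (r + x) / (r - x);
    }

    public static void main(String[] args) {
        for (double x : new double[] {-0.34, 0.0, 0.2, 0.34}) {
            System.out.println(x + ": " + expSmall(x) + " vs " + Math.exp(x));
        }
    }
}
```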
    

    The scale(x,n) function implements the multiplication x*2^n by directly manipulating the exponent in the bit assembly of the double floating point format.
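
    Stripped of the subnormal, overflow and underflow branches, the core of that trick might look like the sketch below. It assumes that both x and the result are normal, finite, nonzero doubles and throws otherwise; the class and method names are made up.

```java
final class ScaleSketch {
    // Multiply x by 2^n by adding n to the 11-bit biased exponent field.
    // Assumption: x and the result are both normal doubles.
    static double scaleNormal(double x, int n) {
        long bits = Double.doubleToLongBits(x);
        int exp = (int) (bits >> 52) & 0x7ff;      // bits 52..62: biased exponent
        if (exp == 0 || exp == 0x7ff) {
            throw new ArithmeticException("x is zero, subnormal, infinite or NaN");
        }
        exp += n;
        if (exp <= 0 || exp > 0x7fe) {
            throw new ArithmeticException("result is not a normal double");
        }
        // Keep the sign (bit 63) and mantissa (bits 0..51), swap in the new exponent.
        return Double.longBitsToDouble((bits & 0x800fffffffffffffL) | ((long) exp << 52));
    }

    public static void main(String[] args) {
        System.out.println(scaleNormal(1.5, 3));
        System.out.println(scaleNormal(-0.75, -2));
    }
}
```

    The mask 0x800fffffffffffffL is the same one used in the scale() listing: it clears the exponent bits and leaves everything else untouched.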


    Computing square roots

    To compute square roots it would be more advantageous to compute them directly. First reduce the interval of approximation arguments via

    sqrt(x)=2^k*sqrt(x/4^k)
    

    which can again be done efficiently by directly manipulating the bit format of double.
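
    As a rough illustration of that reduction, the sketch below uses the library calls Math.getExponent and Math.scalb in place of raw bit operations. It assumes x is positive, finite and not subnormal, and the names are hypothetical.

```java
final class SqrtReductionSketch {
    // Exponent k such that x / 4^k lies in [0.5, 2).
    // Assumption: x is positive, finite and normal.
    static int reductionExponent(double x) {
        int e = Math.getExponent(x);      // unbiased exponent: x = m * 2^e with 1 <= m < 2
        return Math.floorDiv(e + 1, 2);   // ceil(e / 2), so e - 2k is 0 or -1
    }

    // x / 4^k; only the exponent changes, so no rounding happens here.
    static double reduce(double x, int k) {
        return Math.scalb(x, -2 * k);
    }

    public static void main(String[] args) {
        double x = 1.0e10;
        int k = reductionExponent(x);
        double m = reduce(x, k);
        System.out.println("k = " + k + ", reduced x = " + m);
        // sqrt(x) = 2^k * sqrt(m); Math.sqrt stands in for an approximation on [0.5, 2).
        System.out.println(Math.scalb(Math.sqrt(m), k) + " vs " + Math.sqrt(x));
    }
}
```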

    After x is reduced to the interval 0.5..2.0 one can then employ formulas of the form

    u = (x-1)/(x+1)
    
    y = (c(u*u)+u) / (c(u*u)-u)
    

    based on

    sqrt(x)=sqrt(1+u)/sqrt(1-u)
    

    and

    c(v) = 1+sqrt(1-v) = 2 - 1/2*v - 1/8*v^2 - 1/16*v^3 - 5/128*v^4 - 7/256*v^5 - 21/1024*v^6 - 33/2048*v^7 - ...
    

    In a program without bit manipulations this could look like

    static double my_sqrt(double x) {
        double c,u,v,y,scale=1;
        if(x<0) return Double.NaN;
        // 0 and infinity are their own square roots; the loops below would never end on them
        if(x==0 || Double.isInfinite(x)) return x;
        // reduce x to 0.5..2 using sqrt(x)=2^k*sqrt(x/4^k), tracking scale=2^k
        while(x>2  ) { x/=4; scale *=2; }
        while(x<0.5) { x*=4; scale /=2; }
        // rational approximation of sqrt
        u = (x-1)/(x+1);
        v = u*u;
        c = 2 - v/2*(1 + v/4*(1 + v/2));
        y = 1 + 2*u/(c-u); // = (c+u)/(c-u);
        // one Halley iteration
        y = y*(y*y+3*x)/(3*y*y+x); // = y*(1+8*x/(3*y*y+x))/3
        // reconstruct original scale
        return y*scale;
    }
    

    One could replace the Halley step with two Newton steps, or with a better uniform approximation in c one could replace the Halley step with one Newton step, or ...
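
    The two-Newton-step variant could be sketched as follows. It keeps the same reduction and the same cubic approximation of c, and only swaps the refinement step; the class and method names are made up, and the guards for 0, infinity and NaN are an addition for safety.

```java
final class NewtonSqrtSketch {
    static double sqrtNewton(double x) {
        if (x < 0) return Double.NaN;
        // 0, infinity and NaN are returned as-is; the loops below could not reduce them.
        if (x == 0 || Double.isInfinite(x) || Double.isNaN(x)) return x;
        double scale = 1;
        // Reduce x to 0.5..2 using sqrt(x) = 2^k * sqrt(x / 4^k).
        while (x > 2)   { x /= 4; scale *= 2; }
        while (x < 0.5) { x *= 4; scale /= 2; }
        // Rational approximation of sqrt on the reduced range.
        double u = (x - 1) / (x + 1);
        double v = u * u;
        double c = 2 - v / 2 * (1 + v / 4 * (1 + v / 2));
        double y = 1 + 2 * u / (c - u);
        // Two Newton steps y <- (y + x/y) / 2 in place of one Halley step.
        y = 0.5 * (y + x / y);
        y = 0.5 * (y + x / y);
        // Reconstruct the original scale.
        return y * scale;
    }

    public static void main(String[] args) {
        for (double x : new double[] {2.0, 0.3, 8.0, 1.0e10}) {
            System.out.println(x + ": " + sqrtNewton(x) + " vs " + Math.sqrt(x));
        }
    }
}
```

    Each Newton step roughly squares the relative error, while a Halley step roughly cubes it, which is why two Newton steps are needed to match one Halley step from the same starting approximation.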