Tags: java, math, probability, numerical-stability

Numerical accuracy with log probability Java implementation


Sometimes when you do calculations with very small probabilities using common data types such as doubles, numerical inaccuracies cascade over multiple calculations and lead to incorrect results. Because of this it is recommended to use log probabilities, which improve numerical stability. I have implemented log probabilities in Java and my implementation works, but it has worse numerical stability than using raw doubles. What is wrong with my implementation? What is an accurate and efficient way to perform many consecutive calculations with small probabilities in Java?

I'm unable to provide a neatly contained demonstration of this problem because the inaccuracies cascade over many calculations. However, here is proof that a problem exists: this submission to a CodeForces contest fails due to numerical accuracy. Running test #7 with added debug prints clearly shows that from day 1774 onward, numerical errors begin cascading until the sum of probabilities drops to 0 (when it should be 1). After replacing my Prob class with a simple wrapper over doubles, the exact same solution passes the tests.

My implementation of multiplying probabilities:

log(a * b) = Math.log(a) + Math.log(b)

My implementation of addition:

log(a + b) = Math.log(a) + Math.log(1 + Math.exp(Math.log(b) - Math.log(a)))

The stability problem is most likely contained within those 2 lines, but here is my entire implementation:

class Prob {

        /** Math explained: https://en.wikipedia.org/wiki/Log_probability
         *  Quick start:
         *      - Instantiate probabilities, eg. Prob a = new Prob(0.75)
         *      - add(), multiply() return new objects, can perform on nulls & NaNs.
         *      - get() returns probability as a readable double */

        /** Logarithmized probability. Note: 0% represented by logP NaN. */
        private double logP;

        /** Construct instance with real probability. */
        public Prob(double real) {
            if (real > 0) this.logP = Math.log(real);
            else this.logP = Double.NaN;
        }

        /** Dummy flag whose only purpose is to select the constructor below. */
        static boolean dontLogAgain = true;

        /** Construct instance with already logarithmized value. */
        public Prob(double logP, boolean anyBooleanHereToChooseThisConstructor) {
            this.logP = logP;
        }

        /** Returns real probability as a double. */
        public double get() {
            return Math.exp(logP);
        }

        @Override
        public String toString() {
            return ""+get();
        }

        /***************** STATIC METHODS BELOW ********************/

        /** Note: returns NaN only when a && b are both NaN/null. */
        public static Prob add(Prob a, Prob b) {
            if (nullOrNaN(a) && nullOrNaN(b)) return new Prob(Double.NaN, dontLogAgain);
            if (nullOrNaN(a)) return copy(b);
            if (nullOrNaN(b)) return copy(a);

            double x = a.logP;
            double y = b.logP;
            double sum = x + Math.log(1 + Math.exp(y - x));
            return new Prob(sum, dontLogAgain);
        }

        /** Note: multiplying by null or NaN produces NaN (repping 0% real prob). */
        public static Prob multiply(Prob a, Prob b) {
            if (nullOrNaN(a) || nullOrNaN(b)) return new Prob(Double.NaN, dontLogAgain);
            return new Prob(a.logP + b.logP, dontLogAgain);
        }

        /** Returns true if p is null or NaN. */
        private static boolean nullOrNaN(Prob p) {
            return (p == null || Double.isNaN(p.logP));
        }

        /** Returns a new instance with the same value as original. */
        private static Prob copy(Prob original) {
            return new Prob(original.logP, dontLogAgain);
        }
    }

Solution

  • The problem was caused by the way Math.exp(z) is used in this line:

    log(a + b) = Math.log(a) + Math.log(1 + Math.exp(Math.log(b) - Math.log(a)))

    When z reaches extreme values, the true value of Math.exp(z) falls outside the range a double can represent, so the result overflows to Infinity or underflows to 0. Information is lost, the result is inaccurate, and the error then cascades through the calculations that follow.

    When z >= 710 then Math.exp(z) = Infinity

    When z <= -746 then Math.exp(z) = 0

    In the original code I was calling Math.exp with y - x and arbitrarily choosing which operand is x and which is y. Let's instead choose x and y based on which is larger, so that z = y - x is never positive. The point where Math.exp underflows is further out on the negative side (746 rather than 710) and, more importantly, when it underflows we end up at 0 rather than Infinity, which is what we want when one probability is negligible next to the other.

    double x = Math.max(a.logP, b.logP);
    double y = Math.min(a.logP, b.logP);
    double sum = x + Math.log(1 + Math.exp(y - x));
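
To make the difference concrete, here is a minimal standalone sketch that compares the order-dependent addition with the max/min version. It is not the Prob class from the question: the names LogSpace, naiveLogAdd and logAdd are made up for illustration, a probability of 0 is assumed to be stored as Double.NEGATIVE_INFINITY instead of NaN, and inputs are assumed never to be NaN. Using Math.log1p in place of Math.log(1 + ...) is an extra refinement that the fix above does not require.

```java
final class LogSpace {
    private LogSpace() {}

    /** Order-dependent form from the question: overflows once y - x reaches about 710. */
    static double naiveLogAdd(double x, double y) {
        return x + Math.log(1 + Math.exp(y - x));
    }

    /**
     * Order-independent form: returns log(exp(x) + exp(y)).
     * Assumption: log(0) is encoded as NEGATIVE_INFINITY and inputs are never NaN.
     */
    static double logAdd(double x, double y) {
        if (x == Double.NEGATIVE_INFINITY) return y; // 0 + b = b
        if (y == Double.NEGATIVE_INFINITY) return x; // a + 0 = a
        double hi = Math.max(x, y);
        double lo = Math.min(x, y);
        // lo - hi <= 0, so Math.exp can only underflow to 0, never overflow.
        // Math.log1p(t) computes log(1 + t) and stays accurate when t is tiny.
        return hi + Math.log1p(Math.exp(lo - hi));
    }

    public static void main(String[] args) {
        double small = -1000.0; // log of a probability too small for a raw double
        double large = 0.0;     // log(1.0)

        System.out.println(naiveLogAdd(small, large)); // Infinity
        System.out.println(naiveLogAdd(large, small)); // 0.0
        System.out.println(logAdd(small, large));      // 0.0
        System.out.println(logAdd(large, small));      // 0.0
    }
}
```

The naive form gives the right answer only when the larger operand happens to come first. One reason to consider NEGATIVE_INFINITY for zero is that Math.log(0) already returns it in Java, and adding it to any finite log value stays at negative infinity, so multiplication needs no special case.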