Tags: java, performance, compiler-optimization, hoisting, loop-invariant

Why does an optimized prime-factor counting algorithm run slower?


Hi, I saw an online answer for counting the distinct prime factors of a number, and it looked non-optimal. So I tried to improve it, but in a simple benchmark my variant is much slower than the original.

The algorithm counts the distinct prime factors of a number. The original uses a HashSet to collect the factors, then calls size() to get their number. My "improved" version uses an int counter and splits each while loop into an if block containing a while loop, to avoid unnecessary calls.

Update: TL;DR (see the accepted answer for details)

The original code had a performance bug, calling Math.sqrt unnecessarily, which the JIT compiler fixed:

int n = ...;
// sqrt does not need to be recomputed if n does not change
for (int i = 3; i <= Math.sqrt(n); i += 2) {
    while (n % i == 0) {
        n /= i;
    }
}

The compiler optimized the code so that sqrt was only recomputed when n changed. But once the loop body became a little more complex (with no functional change), the compiler stopped performing that optimization, and sqrt was called on every iteration.
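To make the difference concrete, here is a minimal sketch that counts how often the square root is evaluated in the loop as written versus a manually hoisted form that recomputes it only when n changes. The class name SqrtCallCount, the counting wrapper sqrt(), and the sample input are hypothetical and chosen for illustration only; the sketch counts calls, it does not measure what the JIT actually does.

```java
public class SqrtCallCount {

    // Number of times sqrt() below has been invoked (illustration only).
    static int sqrtCalls = 0;

    // Hypothetical counting wrapper around Math.sqrt.
    static double sqrt(double x) {
        sqrtCalls++;
        return Math.sqrt(x);
    }

    // Loop as in the question: the bound is evaluated on every condition check.
    static int naive(int n) {
        for (int i = 3; i <= sqrt(n); i += 2) {
            while (n % i == 0) {
                n /= i;
            }
        }
        return n; // whatever remains after dividing out the odd factors found
    }

    // Hoisted form: the bound is recomputed only when n actually changes.
    static int hoisted(int n) {
        double sn = sqrt(n);
        for (int i = 3; i <= sn; i += 2) {
            while (n % i == 0) {
                n /= i;
                sn = sqrt(n);
            }
        }
        return n;
    }

    public static void main(String[] args) {
        int n = 10007; // a prime, so the inner while loop never runs

        sqrtCalls = 0;
        int a = naive(n);
        int naiveCalls = sqrtCalls;

        sqrtCalls = 0;
        int b = hoisted(n);
        int hoistedCalls = sqrtCalls;

        System.out.println("same result: " + (a == b));
        System.out.println("sqrt calls, naive: " + naiveCalls + ", hoisted: " + hoistedCalls);
    }
}
```

For a prime input such as 10007, the hoisted form evaluates the square root once, while the loop as written evaluates it on every condition check (50 times here), even though both return the same value.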

Original question

import java.util.HashSet;
import java.util.Set;

public class PrimeFactors {

    // fast version, takes 10s for input 8
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // slow version, takes 19s for input 8
    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count ++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    static int findNumberWithNPrimeFactors(final int n) {
        for (int i = 3; ; i++) {
            // switch implementations
            if (countPrimeFactorsCounter(i) == n) {
            // if (countPrimeFactorsSet(i) == n) {
                return i;
            }
        }
    }

    public static void main(String[] args) {
        findNumberWithNPrimeFactors(8); // benchmark warmup
        findNumberWithNPrimeFactors(8);
        long start = System.currentTimeMillis();
        int result = findNumberWithNPrimeFactors(8);
        long duration = System.currentTimeMillis() - start;

        System.out.println("took ms " + duration + " to find " + result);
    }
}

The original version consistently takes around 10s (on Java 8), whereas the "optimized" version takes closer to 20s (both print the same result). In fact, just changing the single while loop into an if block containing a while loop already slows the original method down to half its speed.

When the JVM runs in interpreted mode with -Xint, the optimized version runs 3 times faster. With -Xcomp, both implementations run at similar speed. So it seems the JIT can optimize the version with a single while loop and a HashSet better than the version with a simple int counter.

Would a proper microbenchmark (How do I write a correct micro-benchmark in Java?) tell me something else? Is there a performance optimization principle I overlooked (e.g. Java performance tips)?


Solution

  • I converted your example into a JMH benchmark to make fair measurements, and indeed the set variant turned out to be twice as fast as the counter variant:

    Benchmark              Mode  Cnt     Score    Error   Units
    PrimeFactors.counter  thrpt    5   717.976 ±  7.232  ops/ms
    PrimeFactors.set      thrpt    5  1410.705 ± 15.894  ops/ms
    

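    To find out why, I reran the benchmark with JMH's built-in -prof xperfasm profiler. It turned out that the counter method spent more than 60% of its time executing the vsqrtsd instruction, which is the compiled counterpart of Math.sqrt(n). (The samples show up on the instruction right after vsqrtsd, because the profiler typically attributes a slow instruction's cost to the one that follows it.)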

      0.02%   │  │  │     │  0x0000000002ab8f3e: vsqrtsd %xmm0,%xmm0,%xmm0    <-- Math.sqrt
     61.27%   │  │  │     │  0x0000000002ab8f42: vcvtsi2sd %r10d,%xmm1,%xmm1
    

    Meanwhile, the hottest instruction of the set method was idiv, the compiled form of n % i (its samples show up on the test instruction that follows it).

                 │  │ ││  0x0000000002ecb9e7: idiv   %ebp               ;*irem
     55.81%      │  ↘ ↘│  0x0000000002ecb9e9: test   %edx,%edx
    

    It's no surprise that Math.sqrt is a slow operation. But why was it executed more frequently in the first case?

    The clue is the transformation you made to the code during optimization. You wrapped a simple while loop in an extra if block. This made the control flow a little more complex, so the JIT failed to hoist the Math.sqrt computation out of the loop and had to recompute it on every iteration.

    We need to help the JIT compiler a bit to bring the performance back. Let's hoist the Math.sqrt computation out of the loop manually.

        static int countPrimeFactorsSet(int n) {
            Set<Integer> primeFactorSet = new HashSet<>();
            while (n % 2 == 0) {
                primeFactorSet.add(2);
                n /= 2;
            }
            double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
            for (int i = 3; i <= sn; i += 2) {
                while (n % i == 0) {
                    primeFactorSet.add(i);
                    n /= i;
                    sn = Math.sqrt(n);     // recompute only when n changes
                }
            }
            if (n > 2) {
                primeFactorSet.add(n);
            }
            return primeFactorSet.size();
        }
    
        static int countPrimeFactorsCounter(int n) {
            int count = 0; // using simple int
            if (n % 2 == 0) {
                count ++; // only add on first division
                n /= 2;
                while (n % 2 == 0) {
                    n /= 2;
                }
            }
            double sn = Math.sqrt(n);  // compute Math.sqrt out of the loop
            for (int i = 3; i <= sn; i += 2) {
                if (n % i == 0) {
                    count++; // only add on first division
                    n /= i;
                    while (n % i == 0) {
                        n /= i;
                    }
                    sn = Math.sqrt(n);     // recompute after n changes
                }
            }
            if (n > 2) {
                count++;
            }
            return count;
        }
    

    Now the counter method is fast, even a bit faster than set. That is to be expected, because it does the same amount of computation without the Set overhead.

    Benchmark              Mode  Cnt     Score    Error   Units
    PrimeFactors.counter  thrpt    5  1513.228 ± 13.046  ops/ms
    PrimeFactors.set      thrpt    5  1411.573 ± 10.004  ops/ms
    

    Note that set performance did not change, because the JIT was able to perform the same optimization on its own, thanks to the simpler control flow graph.
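    Another way to sidestep the problem, not used in the answer above, is to drop the floating-point bound altogether and test (long) i * i <= n, which is mathematically equivalent to i <= Math.sqrt(n) for positive ints. The sketch below applies this to the counter variant. The class name PrimeFactorsNoSqrt is hypothetical, and whether this is actually faster on a given JVM should be checked with JMH rather than assumed.

```java
public class PrimeFactorsNoSqrt {

    // Counts distinct prime factors using an integer loop bound instead of Math.sqrt.
    static int countPrimeFactors(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++; // count 2 once
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        // Widening to long keeps i * i from overflowing when n is close to Integer.MAX_VALUE.
        for (int i = 3; (long) i * i <= n; i += 2) {
            if (n % i == 0) {
                count++; // count each odd prime factor once
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++; // the remaining n is a prime factor larger than the loop bound
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countPrimeFactors(45045)); // 3 * 3 * 5 * 7 * 11 * 13, prints 5
    }
}
```

    For example, countPrimeFactors(9699690) returns 8, since 9699690 = 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19 is the smallest number with eight distinct prime factors.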

    Conclusion: Java performance is a complicated matter, especially when it comes to micro-optimizations. JIT optimizations are fragile, and it's hard to understand what the JVM is doing without specialized tools such as JMH and profilers.