I saw an online answer for counting the distinct prime factors of a number, and it looked suboptimal. So I tried to improve it, but in a simple benchmark my variant is much slower than the original.
The algorithm counts the distinct prime factors of a number. The original uses a HashSet to collect the factors, then calls size() to get their number. My "improved" version uses an int counter and splits each while loop into an if with a nested while, so that a factor is only counted on its first division and no redundant calls are made.
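To make "distinct" concrete: 360 = 2^3 * 3^2 * 5 has three distinct prime factors but six prime factors counted with multiplicity. The sketch below (the class and method names are hypothetical, not from the question) uses deliberately naive trial division to show both counts side by side; the distinct version uses the same "count on first division only" idea as the int-counter variant.

```java
public class DistinctVsTotal {
    // Counts prime factors with multiplicity: every successful division counts.
    static int countWithMultiplicity(int n) {
        int count = 0;
        for (int p = 2; n > 1; p++) {
            while (n % p == 0) {
                count++;
                n /= p;
            }
        }
        return count;
    }

    // Counts each prime factor once: only the first division by p counts.
    static int countDistinct(int n) {
        int count = 0;
        for (int p = 2; n > 1; p++) {
            if (n % p == 0) {
                count++;
                while (n % p == 0) {
                    n /= p;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countWithMultiplicity(360)); // 6
        System.out.println(countDistinct(360));         // 3
    }
}
```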
Update: tl;dr (see the accepted answer for details)
The original code had a performance bug, calling Math.sqrt unnecessarily, which the JIT compiler happened to fix:
int n = ...;
// sqrt does not need to be recomputed if n does not change
for (int i = 3; i <= Math.sqrt(n); i += 2) {
    while (n % i == 0) {
        n /= i;
    }
}
The JIT compiler optimized the sqrt call so that it only happens when n changes. But after the loop body became a little more complex (with no functional change), the compiler stopped optimizing it that way, and sqrt was called on every iteration.
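Taken literally, as in interpreted mode, the loop condition evaluates Math.sqrt(n) on every iteration; recomputing it only when n changes is an optimization the JIT may or may not apply. A minimal sketch that makes the difference visible by routing sqrt through a counting helper (the class, the sqrtCalls counter and both method names are hypothetical; they assume an odd input, so the factor-2 step is left out):

```java
public class SqrtCallCount {
    static long sqrtCalls = 0;

    // Wraps Math.sqrt so we can count how often the loop bound is evaluated.
    static double countedSqrt(int n) {
        sqrtCalls++;
        return Math.sqrt(n);
    }

    // Bound in the loop condition: sqrt runs once per iteration, plus the final failing check.
    static int oddDistinctInCondition(int n) {
        int count = 0;
        for (int i = 3; i <= countedSqrt(n); i += 2) {
            if (n % i == 0) {
                count++;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        return count + (n > 2 ? 1 : 0); // leftover n > 2 is a prime factor
    }

    // Bound hoisted: sqrt runs once up front and again only after n has changed.
    static int oddDistinctHoisted(int n) {
        int count = 0;
        double limit = countedSqrt(n);
        for (int i = 3; i <= limit; i += 2) {
            if (n % i == 0) {
                count++;
                while (n % i == 0) {
                    n /= i;
                }
                limit = countedSqrt(n);
            }
        }
        return count + (n > 2 ? 1 : 0);
    }

    public static void main(String[] args) {
        int n = 3 * 10007; // one small factor and one large prime factor
        sqrtCalls = 0;
        int a = oddDistinctInCondition(n);
        System.out.println("in condition: " + a + " factors, " + sqrtCalls + " sqrt calls"); // prints: in condition: 2 factors, 50 sqrt calls
        sqrtCalls = 0;
        int b = oddDistinctHoisted(n);
        System.out.println("hoisted: " + b + " factors, " + sqrtCalls + " sqrt calls"); // prints: hoisted: 2 factors, 2 sqrt calls
    }
}
```

Both methods return the same count; only the number of sqrt evaluations differs, which is exactly the work the JIT can or cannot remove.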
Original question
import java.util.HashSet;
import java.util.Set;

public class PrimeFactors {
    // fast version, takes 10s for input 8
    static int countPrimeFactorsSet(int n) {
        Set<Integer> primeFactorSet = new HashSet<>();
        while (n % 2 == 0) {
            primeFactorSet.add(2);
            n /= 2;
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            while (n % i == 0) {
                primeFactorSet.add(i);
                n /= i;
            }
        }
        if (n > 2) {
            primeFactorSet.add(n);
        }
        return primeFactorSet.size();
    }

    // slow version, takes 19s for input 8
    static int countPrimeFactorsCounter(int n) {
        int count = 0; // using simple int
        if (n % 2 == 0) {
            count++; // only add on first division
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++; // only add on first division
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    static int findNumberWithNPrimeFactors(final int n) {
        for (int i = 3; ; i++) {
            // switch implementations
            if (countPrimeFactorsCounter(i) == n) {
            // if (countPrimeFactorsSet(i) == n) {
                return i;
            }
        }
    }

    public static void main(String[] args) {
        findNumberWithNPrimeFactors(8); // benchmark warmup
        findNumberWithNPrimeFactors(8);
        long start = System.currentTimeMillis();
        int result = findNumberWithNPrimeFactors(8);
        long duration = System.currentTimeMillis() - start;
        System.out.println("took ms " + duration + " to find " + result);
    }
}
The original version consistently takes around 10 s (on Java 8), whereas the "optimized" version takes closer to 20 s (both print the same result). In fact, just changing the single while loop into an if block containing a while loop already slows the original method down to half its speed.
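It helps to know how much work one run does. Any number with k distinct prime factors is at least the product of the first k primes, so the search for 8 factors ends at 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19 = 9699690, after factorizing roughly 9.7 million candidates. A small sketch that checks this (the class name is hypothetical; countDistinct and findFirst copy the logic of the question's counter version and search loop):

```java
public class PrimorialCheck {
    // Same logic as countPrimeFactorsCounter in the question.
    static int countDistinct(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++;
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++;
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    static boolean isPrime(int p) {
        if (p < 2) {
            return false;
        }
        for (int d = 2; d * d <= p; d++) {
            if (p % d == 0) {
                return false;
            }
        }
        return true;
    }

    // Product of the first k primes: the smallest number with k distinct prime factors.
    static long primorial(int k) {
        long product = 1;
        int found = 0;
        for (int candidate = 2; found < k; candidate++) {
            if (isPrime(candidate)) {
                product *= candidate;
                found++;
            }
        }
        return product;
    }

    // Same search as findNumberWithNPrimeFactors: first i >= 3 with k distinct factors.
    static int findFirst(int k) {
        for (int i = 3; ; i++) {
            if (countDistinct(i) == k) {
                return i;
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(primorial(8));                      // 9699690
        System.out.println(countDistinct((int) primorial(8))); // 8
        System.out.println(findFirst(4));                      // 210, fast smaller case
    }
}
```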
Using -Xint to run the JVM in interpreted mode, the optimized version runs 3 times faster. Using -Xcomp makes both implementations run at similar speed.
So it seems the JIT can optimize the version with a single while-loop and a HashSet more than the version with a simple int counter.
Would a proper microbenchmark (How do I write a correct micro-benchmark in Java?) tell me something else? Is there a performance optimization principle I overlooked (e.g. Java performance tips)?
I converted your example into a JMH benchmark to make fair measurements, and indeed the set variant turned out to be twice as fast as counter:
Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5   717,976 ±  7,232  ops/ms
PrimeFactors.set      thrpt    5  1410,705 ± 15,894  ops/ms
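To find out why, I reran the benchmark with the built-in -prof xperfasm profiler. It turned out that the counter method spent more than 60% of its time executing the vsqrtsd instruction, obviously the compiled counterpart of Math.sqrt(n) (as is usual with sampling profilers, the samples are attributed to the instruction right after the expensive one).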
0,02% │ │ │ │ 0x0000000002ab8f3e: vsqrtsd %xmm0,%xmm0,%xmm0 <-- Math.sqrt
61,27% │ │ │ │ 0x0000000002ab8f42: vcvtsi2sd %r10d,%xmm1,%xmm1
At the same time, the hottest instruction of the set method was idiv, the result of compiling n % i.
│ │ ││ 0x0000000002ecb9e7: idiv %ebp ;*irem
55,81% │ ↘ ↘│ 0x0000000002ecb9e9: test %edx,%edx
It's no surprise that Math.sqrt is a slow operation. But why was it executed more frequently in the first case?
The clue is the transformation you made to the code during optimization. You wrapped a simple while loop in an extra if block. This made the control flow a little more complex, so the JIT failed to hoist the Math.sqrt computation out of the loop and had to recompute it on every iteration.
We need to help the JIT compiler a bit to bring the performance back. Let's hoist the Math.sqrt computation out of the loop manually.
static int countPrimeFactorsSet(int n) {
    Set<Integer> primeFactorSet = new HashSet<>();
    while (n % 2 == 0) {
        primeFactorSet.add(2);
        n /= 2;
    }
    double sn = Math.sqrt(n); // compute Math.sqrt out of the loop
    for (int i = 3; i <= sn; i += 2) {
        while (n % i == 0) {
            primeFactorSet.add(i);
            n /= i;
        }
        sn = Math.sqrt(n); // recompute, since n may have changed in the while loop
    }
    if (n > 2) {
        primeFactorSet.add(n);
    }
    return primeFactorSet.size();
}

static int countPrimeFactorsCounter(int n) {
    int count = 0; // using simple int
    if (n % 2 == 0) {
        count++; // only add on first division
        n /= 2;
        while (n % 2 == 0) {
            n /= 2;
        }
    }
    double sn = Math.sqrt(n); // compute Math.sqrt out of the loop
    for (int i = 3; i <= sn; i += 2) {
        if (n % i == 0) {
            count++; // only add on first division
            n /= i;
            while (n % i == 0) {
                n /= i;
            }
            sn = Math.sqrt(n); // recompute after n changes
        }
    }
    if (n > 2) {
        count++;
    }
    return count;
}
Now the counter method is fast! It is even a bit faster than set, which is expected, because it does the same amount of computation without the Set overhead.
Benchmark              Mode  Cnt     Score    Error   Units
PrimeFactors.counter  thrpt    5  1513,228 ± 13,046  ops/ms
PrimeFactors.set      thrpt    5  1411,573 ± 10,004  ops/ms
Note that the performance of set did not change, because the JIT was able to do the same optimization itself, thanks to a simpler control-flow graph.
Conclusion: Java performance is a really complicated topic, especially when it comes to micro-optimizations. JIT optimizations are fragile, and it's hard to understand what the JVM is doing without specialized tools like JMH and profilers.
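A different way to sidestep the sqrt question, which was not benchmarked above, is to keep floating point out of the loop bound altogether: test (long) i * i <= n instead of i <= Math.sqrt(n). For int inputs the two conditions select the same values of i, and the cast to long keeps i * i from overflowing when n is close to Integer.MAX_VALUE. The sketch below (class and method names are hypothetical) only checks that the results agree; whether it is actually faster on a given JVM would need the same JMH measurement.

```java
public class IntegerBound {
    // Copy of the question's counter version, with Math.sqrt in the loop condition.
    static int countWithSqrt(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++;
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; i <= Math.sqrt(n); i += 2) {
            if (n % i == 0) {
                count++;
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    // Same algorithm, but the loop bound uses integer arithmetic only.
    static int countWithIntBound(int n) {
        int count = 0;
        if (n % 2 == 0) {
            count++;
            n /= 2;
            while (n % 2 == 0) {
                n /= 2;
            }
        }
        for (int i = 3; (long) i * i <= n; i += 2) { // long avoids int overflow of i * i
            if (n % i == 0) {
                count++;
                n /= i;
                while (n % i == 0) {
                    n /= i;
                }
            }
        }
        if (n > 2) {
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        // Positive inputs only: like the original, n = 0 would loop forever.
        for (int n = 2; n < 200_000; n++) {
            if (countWithSqrt(n) != countWithIntBound(n)) {
                throw new AssertionError("mismatch at " + n);
            }
        }
        System.out.println(countWithIntBound(Integer.MAX_VALUE)); // 1 (2^31 - 1 is prime)
        System.out.println(countWithIntBound(9699690));           // 8
    }
}
```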