Search code examples
Tags: java, jvm, jit

Java 8u40 Math.round() very slow


I have a fairly simple hobby project written in Java 8 that makes extensive use of repeated Math.round() calls in one of its modes of operation. For example, one such mode spawns 4 threads and queues 48 runnable tasks by way of an ExecutorService, each of which runs something similar to the following block of code 2^31 times:

int3 = Math.round(float1 + float2);
int3 = Math.round(float1 * float2);
int3 = Math.round(float1 / float2);

That's not exactly how it is (there are arrays involved, and nested loops), but you get the idea. Anyway, prior to Java 8u40, the code that resembles the above could complete the full run of ~103 billion instruction blocks in about 13 seconds on an AMD A10-7700k. With Java 8u40 it takes around 260 seconds to do the same thing. No changes to code, no nothing, just a Java update.

Has anyone else noticed Math.round() getting much slower, especially when it is called repeatedly? It's almost as though the JVM was doing some sort of optimization before that it isn't doing anymore. Maybe it was using SIMD prior to 8u40 and it isn't now?

edit: I have completed my second attempt at an MVCE. You can download the first attempt here:

https://www.dropbox.com/s/rm2ftcv8y6ye1bi/MathRoundMVCE.zip?dl=0

The second attempt is below. My first attempt has been otherwise removed from this post as it has been deemed to be too long, and is prone to dead code removal optimizations by the JVM (which apparently are happening less in 8u40).

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Entry point of the Math.round() benchmark: fills the shared input arrays, runs the
 * benchmark through {@link OmniCode}, then prints the checksum and the aggregated time.
 */
public class MathRoundMVCE
{
    /** Aggregated execution time in milliseconds, set by main(). */
    static long grandtotal = 0;
    /**
     * Checksum written by the benchmark tasks from several pool threads. Volatile so that
     * the 64-bit write is atomic and cannot tear (JLS 17.7).
     */
    static volatile long sumtotal = 0;

    static float[] float4 = new float[128];
    static float[] float5 = new float[128];
    static int[] int6 = new int[128];
    static int[] int7 = new int[128];
    static int[] int8 = new int[128];
    /** Per-task net timings in nanoseconds, indexed by task number. */
    static long[] longarray = new long[480];

    static final int mil = 1000000;

    public static void main(String[] args)
    {
        initmainarrays();
        OmniCode omni = new OmniCode();
        // runloops() yields nanoseconds; convert to milliseconds for reporting.
        grandtotal = omni.runloops() / mil;
        System.out.println("Total sum of operations is " + sumtotal);
        System.out.println("Total execution time is " + grandtotal + " milliseconds");
    }

    /**
     * Sums the positive entries of {@code larray} and divides by the processor count.
     *
     * @param larray per-task timings; non-positive entries (e.g. untouched zero slots) are skipped
     * @return the sum of the positive entries divided by availableProcessors()
     */
    public static long siftarray(long[] larray)
    {
        long topnum = 0;
        // Enhanced for uses an int index internally; a short index would wrap past 32767.
        for (long value : larray)
        {
            if (value > 0)
            {
                topnum += value;
            }
        }
        return topnum / Runtime.getRuntime().availableProcessors();
    }

    /** Fills float4/float5 with random values in [1, 13) and clears int6. */
    public static void initmainarrays()
    {
        for (int k = 0; k < float4.length; k++)
        {
            float4[k] = (float) (Math.random() * 12) + 1f;
            float5[k] = (float) (Math.random() * 12) + 1f;
            int6[k] = 0;
        }
    }
}

/**
 * Benchmark driver. Runs a warm-up batch and then a measured batch of
 * {@link RoundFloatToIntAlternate} tasks on a fixed-size thread pool. It then aggregates
 * the per-task timings with {@link MathRoundMVCE#siftarray(long[])}.
 */
class OmniCode extends Thread
{
    /** Tasks queued per batch; each task writes only its own slot of MathRoundMVCE.longarray. */
    private static final int TASKS = 48;

    volatile long totaltime = 0;
    final int standard = 16777216;
    final int warmup = 200000;

    /**
     * Processor count, kept as a byte for compatibility. It is clamped to Byte.MAX_VALUE so
     * it can never wrap negative on machines with more than 127 CPUs.
     */
    byte threads = 0;

    /**
     * Runs the warm-up batch, then the measured batch, and returns the aggregated timing.
     * The calling thread runs at MIN_PRIORITY while it waits; its priority is restored afterwards.
     *
     * @return MathRoundMVCE.siftarray(MathRoundMVCE.longarray) after the measured batch
     */
    public long runloops()
    {
        // this.setPriority(...) would only touch this never-started Thread object;
        // the thread that actually waits on the pool is the caller.
        Thread caller = Thread.currentThread();
        int callerPriority = caller.getPriority();
        caller.setPriority(MIN_PRIORITY);
        try
        {
            int poolSize = Runtime.getRuntime().availableProcessors();
            threads = (byte) Math.min(poolSize, Byte.MAX_VALUE);

            // Warm-up batch first; each task resets its own timing slot when it starts,
            // so only the measured batch's timings survive.
            runBatch(warmup, poolSize);
            runBatch(standard, poolSize);

            totaltime = MathRoundMVCE.siftarray(MathRoundMVCE.longarray);
            Runtime.getRuntime().gc();
            return totaltime;
        }
        finally
        {
            caller.setPriority(callerPriority);
        }
    }

    /** Runs TASKS tasks with the given cycle count on a fresh pool and blocks until all finish. */
    private static void runBatch(int cycles, int poolSize)
    {
        ExecutorService executor = Executors.newFixedThreadPool(poolSize);
        try
        {
            for (int task = 0; task < TASKS; task++)
            {
                executor.execute(new RoundFloatToIntAlternate(cycles, (byte) task));
            }
        }
        finally
        {
            // Stop accepting work even if submission failed, so the pool threads can exit.
            executor.shutdown();
        }
        awaitUninterruptibly(executor);
    }

    /**
     * Blocks until {@code executor} has terminated. An interrupt does not abort the wait,
     * because the timings are only meaningful once every task has finished. The interrupt
     * status is re-asserted before returning.
     */
    private static void awaitUninterruptibly(ExecutorService executor)
    {
        boolean interrupted = false;
        while (!executor.isTerminated())
        {
            try
            {
                executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
            }
            catch (InterruptedException e)
            {
                interrupted = true;
            }
        }
        if (interrupted)
        {
            Thread.currentThread().interrupt();
        }
    }
}

class RoundFloatToIntAlternate extends Thread
{       
    int i = 0;
    int j = 0;
    int int3 = 0;
    int iterations = 0;
    byte thread = 0;

    public RoundFloatToIntAlternate(int cycles, byte threadnumber)
    {
        iterations = cycles;
        thread = threadnumber;
    }

    public void run()
    {
        this.setPriority(9);
        MathRoundMVCE.longarray[this.thread] = 0;
        mainloop();
        blankloop();    

    }

    public void blankloop()
    {
        j = 0;
        long timer = 0;
        long totaltimer = 0;

        do
        {   
            timer = System.nanoTime();
            i = 0;

            do
            {
                i++;
            }
            while (i < 128);
            totaltimer += System.nanoTime() - timer;            

            j++;
        }
        while (j < iterations);         

        MathRoundMVCE.longarray[this.thread] -= totaltimer;
    }

    public void mainloop()
    {
        j = 0;
        long timer = 0; 
        long totaltimer = 0;
        long localsum = 0;

        int[] int6 = new int[128];
        int[] int7 = new int[128];
        int[] int8 = new int[128];

        do
        {   
            timer = System.nanoTime();
            i = 0;

            do
            {
                int6[i] = Math.round(MathRoundMVCE.float4[i] + MathRoundMVCE.float5[i]);
                int7[i] = Math.round(MathRoundMVCE.float4[i] * MathRoundMVCE.float5[i]);
                int8[i] = Math.round(MathRoundMVCE.float4[i] / MathRoundMVCE.float5[i]);

                i++;
            }
            while (i < 128);
            totaltimer += System.nanoTime() - timer;

            for(short z = 0; z < 128; z++)
            {
                localsum += int6[z] + int7[z] + int8[z];
            }       

            j++;
        }
        while (j < iterations);         

        MathRoundMVCE.longarray[this.thread] += totaltimer;
        MathRoundMVCE.sumtotal = localsum;
    }
}

Long story short, this code performed about the same in 8u25 as in 8u40. As you can see, I am now recording the results of all calculations into arrays, and then summing those arrays outside of the timed portion of the loop to a local variable which is then written to a static variable at the end of the outer loop.

Under 8u25: Total execution time is 261545 milliseconds

Under 8u40: Total execution time is 266890 milliseconds

Test conditions were the same as before. So, it would appear that 8u25 and 8u31 were doing dead code removal that 8u40 stopped doing, causing the code to "slow down" in 8u40. That doesn't explain every weird little thing that's cropped up but that appears to be the bulk of it. As an added bonus, the suggestions and answers provided here have given me inspiration to improve the other parts of my hobby project, for which I am quite grateful. Thank you all for that!


Solution

  • MVCE based on OP

    • can likely be simplified further
    • changed the "int3 =" statements to "int3 +=" to reduce the chance of dead-code removal. With "int3 =", 8u40 is about 3x slower than 8u31; with "int3 +=", it is only about 15% slower.
    • print result to further reduce chance of dead code removal optimisations

    Code

    /**
     * Micro-benchmark of Math.round(float) applied to float addition, multiplication and division.
     * The same workload runs twice ("warmup", then "real"). Each run prints the accumulated
     * checksum and the elapsed wall-clock time, together with the running JDK version.
     */
    public class MathTime {
        /**
         * Loop iterations per run. This is an int constant rather than a double literal such as
         * 1e7. Comparing the int counter against a double promotes it on every test, and a
         * double bound above Integer.MAX_VALUE would never be reached.
         */
        static final int ITERATIONS = 10_000_000;

        static float[][] float1 = new float[8][16];
        static float[][] float2 = new float[8][16];
    
        public static void main(String[] args) {
            // Deterministic inputs: float1[j][k] == float2[j][k] == j + k.
            for (int j = 0; j < 8; j++) {
                for (int k = 0; k < 16; k++) {
                    float1[j][k] = (float) (j + k);
                    float2[j][k] = (float) (j + k);
                }
            }
            new Test().run();
        }
    
        private static class Test {
            /**
             * Accumulator that keeps every rounded value in use, reducing the chance that the JIT
             * removes the work as dead code. It is an int and may wrap on overflow; only
             * its use matters, not its value.
             */
            int int3;
    
            public void run() {
                for (String test : new String[] { "warmup", "real" }) {
    
                    long t0 = System.nanoTime();
    
                    // Statements are kept hand-unrolled so the measured code matches the benchmark as published.
                    for (int count = 0; count < ITERATIONS; count++) {
                        int i = count % 8;
                        int3 += Math.round(float1[i][0] + float2[i][0]);
                        int3 += Math.round(float1[i][1] + float2[i][1]);
                        int3 += Math.round(float1[i][2] + float2[i][2]);
                        int3 += Math.round(float1[i][3] + float2[i][3]);
                        int3 += Math.round(float1[i][4] + float2[i][4]);
                        int3 += Math.round(float1[i][5] + float2[i][5]);
                        int3 += Math.round(float1[i][6] + float2[i][6]);
                        int3 += Math.round(float1[i][7] + float2[i][7]);
                        int3 += Math.round(float1[i][8] + float2[i][8]);
                        int3 += Math.round(float1[i][9] + float2[i][9]);
                        int3 += Math.round(float1[i][10] + float2[i][10]);
                        int3 += Math.round(float1[i][11] + float2[i][11]);
                        int3 += Math.round(float1[i][12] + float2[i][12]);
                        int3 += Math.round(float1[i][13] + float2[i][13]);
                        int3 += Math.round(float1[i][14] + float2[i][14]);
                        int3 += Math.round(float1[i][15] + float2[i][15]);
    
                        int3 += Math.round(float1[i][0] * float2[i][0]);
                        int3 += Math.round(float1[i][1] * float2[i][1]);
                        int3 += Math.round(float1[i][2] * float2[i][2]);
                        int3 += Math.round(float1[i][3] * float2[i][3]);
                        int3 += Math.round(float1[i][4] * float2[i][4]);
                        int3 += Math.round(float1[i][5] * float2[i][5]);
                        int3 += Math.round(float1[i][6] * float2[i][6]);
                        int3 += Math.round(float1[i][7] * float2[i][7]);
                        int3 += Math.round(float1[i][8] * float2[i][8]);
                        int3 += Math.round(float1[i][9] * float2[i][9]);
                        int3 += Math.round(float1[i][10] * float2[i][10]);
                        int3 += Math.round(float1[i][11] * float2[i][11]);
                        int3 += Math.round(float1[i][12] * float2[i][12]);
                        int3 += Math.round(float1[i][13] * float2[i][13]);
                        int3 += Math.round(float1[i][14] * float2[i][14]);
                        int3 += Math.round(float1[i][15] * float2[i][15]);
    
                        int3 += Math.round(float1[i][0] / float2[i][0]);
                        int3 += Math.round(float1[i][1] / float2[i][1]);
                        int3 += Math.round(float1[i][2] / float2[i][2]);
                        int3 += Math.round(float1[i][3] / float2[i][3]);
                        int3 += Math.round(float1[i][4] / float2[i][4]);
                        int3 += Math.round(float1[i][5] / float2[i][5]);
                        int3 += Math.round(float1[i][6] / float2[i][6]);
                        int3 += Math.round(float1[i][7] / float2[i][7]);
                        int3 += Math.round(float1[i][8] / float2[i][8]);
                        int3 += Math.round(float1[i][9] / float2[i][9]);
                        int3 += Math.round(float1[i][10] / float2[i][10]);
                        int3 += Math.round(float1[i][11] / float2[i][11]);
                        int3 += Math.round(float1[i][12] / float2[i][12]);
                        int3 += Math.round(float1[i][13] / float2[i][13]);
                        int3 += Math.round(float1[i][14] / float2[i][14]);
                        int3 += Math.round(float1[i][15] / float2[i][15]);
    
                    }
                    long t1 = System.nanoTime();
                    // Printing the accumulator makes the computed values observable.
                    System.out.println(int3);
                    System.out.println(String.format("%s, Math.round(float), %s, %.1f ms", System.getProperty("java.version"), test, (t1 - t0) / 1e6));
                }
            }
        }
    }
    

    Results

    adam@brimstone:~$ ./jdk1.8.0_40/bin/javac MathTime.java;./jdk1.8.0_40/bin/java -cp . MathTime 
    1.8.0_40, Math.round(float), warmup, 6846.4 ms
    1.8.0_40, Math.round(float), real, 6058.6 ms
    adam@brimstone:~$ ./jdk1.8.0_31/bin/javac MathTime.java;./jdk1.8.0_31/bin/java -cp . MathTime 
    1.8.0_31, Math.round(float), warmup, 5717.9 ms
    1.8.0_31, Math.round(float), real, 5282.7 ms
    adam@brimstone:~$ ./jdk1.8.0_25/bin/javac MathTime.java;./jdk1.8.0_25/bin/java -cp . MathTime 
    1.8.0_25, Math.round(float), warmup, 5702.4 ms
    1.8.0_25, Math.round(float), real, 5262.2 ms
    

    Observations

    • For trivial uses of Math.round(float) I can find no difference in performance on my platform (Linux x86_64). The only difference is in the benchmark itself: my previous, naive and incorrect benchmarks only exposed differences in optimisation behaviour, as Ivan's answer and Marco13's comments point out.
    • 8u40 is less aggressive in dead code elimination than previous versions, meaning more code is executed in some corner cases and hence slower.
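    • 8u40 takes slightly longer to warm up, but gains more from warm-up. Its real run is about 11.5% faster than its warmup run, versus about 7.6% for 8u31. In absolute terms it is still about 15% slower than 8u31 in the results above.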

    Source analysis

    Surprisingly, Math.round(float) is a pure Java implementation rather than a native one, and the code is identical in 8u31 and 8u40.

    diff  jdk1.8.0_31/src/java/lang/Math.java jdk1.8.0_40/src/java/lang/Math.java
    -no differences-
    
    /**
     * Returns the int closest to {@code a}, rounding ties toward positive infinity,
     * i.e. floor(a + 1/2). The result is computed directly from the IEEE 754 bit pattern.
     *
     * @param a the float to round
     * @return the rounded value; values that are not in the fast-path range (tiny, already
     *         integral, infinite or NaN) are handled by the {@code (int) a} narrowing
     *         conversion in the else-branch
     */
    public static int round(float a) {
        // Raw IEEE 754 bits: sign bit, biased exponent field and stored significand.
        int intBits = Float.floatToRawIntBits(a);
        // Isolate the biased exponent field and move it down to the low-order bits.
        int biasedExp = (intBits & FloatConsts.EXP_BIT_MASK)
                >> (FloatConsts.SIGNIFICAND_WIDTH - 1);
        // Right-shift that turns the full significand into floor(a * 2) (see below).
        // NOTE(review): assumes SIGNIFICAND_WIDTH counts the implicit leading bit — confirm in FloatConsts.
        int shift = (FloatConsts.SIGNIFICAND_WIDTH - 2
                + FloatConsts.EXP_BIAS) - biasedExp;
        if ((shift & -32) == 0) { // shift >= 0 && shift < 32
            // a is a finite number such that pow(2,-32) <= ulp(a) < 1
            // Full significand: stored bits plus, presumably, the implicit leading 1 bit
            // that sits just above SIGNIF_BIT_MASK.
            int r = ((intBits & FloatConsts.SIGNIF_BIT_MASK)
                    | (FloatConsts.SIGNIF_BIT_MASK + 1));
            // Negative input: negate so the arithmetic shifts below floor toward -infinity.
            if (intBits < 0) {
                r = -r;
            }
            // In the comments below each Java expression evaluates to the value
            // the corresponding mathematical expression:
            // (r) evaluates to a / ulp(a)
            // (r >> shift) evaluates to floor(a * 2)
            // ((r >> shift) + 1) evaluates to floor((a + 1/2) * 2)
            // (((r >> shift) + 1) >> 1) evaluates to floor(a + 1/2)
            return ((r >> shift) + 1) >> 1;
        } else {
            // a is either
            // - a finite number with abs(a) < pow(2,FloatConsts.SIGNIFICAND_WIDTH-32) < 1/2
            // - a finite number with ulp(a) >= 1 and hence a is a mathematical integer
            // - an infinity or NaN
            return (int) a;
        }
    }