Tags: java, concurrency, puzzle, atomic

What is wrong with this thread-safe byte sequence generator?


I need a byte generator that produces values from Byte.MIN_VALUE to Byte.MAX_VALUE. When it reaches MAX_VALUE, it should start over from MIN_VALUE.

I have written the code below using AtomicInteger. However, it misbehaves when it is accessed concurrently and slowed down artificially with Thread.sleep(). Without the sleeps it appears to run fine, but I suspect it is simply too fast for the concurrency problems to show up.

The code (with some added debug code):

import java.util.concurrent.atomic.AtomicInteger;

public class ByteGenerator {

    private static final int INITIAL_VALUE = Byte.MIN_VALUE-1;

    private AtomicInteger counter = new AtomicInteger(INITIAL_VALUE);
    private AtomicInteger resetCounter = new AtomicInteger(0);

    private boolean isSlow = false;
    private long startTime;

    public byte nextValue() {
        int next = counter.incrementAndGet();
        //if (isSlow) slowDown(5);
        if (next > Byte.MAX_VALUE) {
            synchronized(counter) {
                int i = counter.get();
                //if value is still larger than max byte value, we reset it
                if (i > Byte.MAX_VALUE) {
                    counter.set(INITIAL_VALUE);
                    resetCounter.incrementAndGet();
                    if (isSlow) slowDownAndLog(10, "resetting");
                } else {
                    if (isSlow) slowDownAndLog(1, "missed");
                }
                next = counter.incrementAndGet();
            }
        }
        return (byte) next;
    }

    private void slowDown(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
        }
    }
    private void slowDownAndLog(long millis, String msg) {
        slowDown(millis);
        System.out.println(resetCounter + " " 
                           + (System.currentTimeMillis()-startTime) + " "
                           + Thread.currentThread().getName() + ": " + msg);
    }

    public void setSlow(boolean isSlow) {
        this.isSlow = isSlow;
    }
    public void setStartTime(long startTime) {
        this.startTime = startTime;
    }

}

And the test:

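import static org.junit.Assert.assertEquals;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Test;

public class ByteGeneratorTest {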

    @Test
    public void testGenerate() throws Exception {
        ByteGenerator g = new ByteGenerator();
        for (int n = 0; n < 10; n++) {
            for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) {
                assertEquals(i, g.nextValue());
            }
        }
    }

    @Test
    public void testGenerateMultiThreaded() throws Exception {
        final ByteGenerator g = new ByteGenerator();
        g.setSlow(true);
        final AtomicInteger[] counters = new AtomicInteger[Byte.MAX_VALUE-Byte.MIN_VALUE+1];
        for (int i = 0; i < counters.length; i++) {
            counters[i] = new AtomicInteger(0);
        }
        Thread[] threads = new Thread[100];
        final CountDownLatch latch = new CountDownLatch(threads.length);
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(new Runnable() {
                public void run() {
                    try {
                        for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) {
                            byte value = g.nextValue();
                            counters[value-Byte.MIN_VALUE].incrementAndGet();
                        }
                    } finally {
                        latch.countDown();
                    }
                }
            }, "generator-client-" + i);
            threads[i].setDaemon(true);
        }
        g.setStartTime(System.currentTimeMillis());
        for (int i = 0; i < threads.length; i++) {
            threads[i].start();
        }
        latch.await();
        // print out the number of hits for each value
        for (int i = 0; i < counters.length; i++) {
            System.out.println("value #" + (i+Byte.MIN_VALUE) + ": " + counters[i].get());
        }
        // each value should have been generated exactly once per thread
        for (int i = 0; i < counters.length; i++) {
            assertEquals("value #" + (i+Byte.MIN_VALUE), threads.length, counters[i].get());
        }
    }

}

On my 2-core machine, value #-128 gets 146 hits, whereas every value should get exactly 100 hits, since there are 100 threads and each one requests a full cycle of 256 values.

If anyone has an idea what's wrong with this code, I'm all ears (or eyes).

UPDATE: For those in a hurry who do not want to scroll down, the correct (and shortest and most elegant) way to solve this in Java is:

public byte nextValue() {
   return (byte) counter.incrementAndGet();
}

Thanks, Heinz!
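Why the one-liner works: a narrowing cast from int to byte keeps only the low 8 bits, so consecutive int values map to consecutive byte values, and 127 + 1 = 128 becomes -128 without any explicit reset. The minimal sketch below (the class `WrapDemo` is hypothetical; it assumes the counter is an AtomicInteger as in the question) checks this at the byte boundary and at the point where the AtomicInteger itself overflows from Integer.MAX_VALUE to Integer.MIN_VALUE. Because 2^32 is a multiple of 256, the byte sequence stays continuous there as well.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class WrapDemo {

    // Same one-liner as in the update: the cast keeps only the low 8 bits.
    static byte nextValue(AtomicInteger counter) {
        return (byte) counter.incrementAndGet();
    }

    public static void main(String[] args) {
        // Just below the byte boundary: 126, 127, then 128 -> -128, 129 -> -127.
        AtomicInteger counter = new AtomicInteger(125);
        for (int i = 0; i < 4; i++) {
            System.out.print(nextValue(counter) + " ");
        }
        System.out.println(); // prints: 126 127 -128 -127

        // Just below int overflow: MAX_VALUE - 1, MAX_VALUE, MIN_VALUE, MIN_VALUE + 1.
        counter.set(Integer.MAX_VALUE - 2);
        for (int i = 0; i < 4; i++) {
            System.out.print(nextValue(counter) + " ");
        }
        System.out.println(); // prints: -2 -1 0 1
    }
}
```

With the counter still initialized to INITIAL_VALUE (Byte.MIN_VALUE - 1), the first call returns -128. Since every incrementAndGet() hands out a distinct int and no thread ever resets the counter, 100 threads making 256 calls each should hit every byte value exactly 100 times.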


Solution

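  • You decide whether to reset based on an old value read with counter.get(), and only afterwards call incrementAndGet(). In the meantime, other threads can push the counter back up to MAX_VALUE, so that incrementAndGet() returns MAX_VALUE + 1 = 128, which the (byte) cast turns into -128. That is where the extra hits on -128 come from.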

    if (next > Byte.MAX_VALUE) {
        synchronized(counter) {
            int i = counter.get(); //here you check whether the counter is still over MAX_VALUE
            if (i > Byte.MAX_VALUE) {
                counter.set(INITIAL_VALUE);
                resetCounter.incrementAndGet();
                if (isSlow) slowDownAndLog(10, "resetting");
            } else {
                if (isSlow) slowDownAndLog(1, "missed"); //the counter can reach MAX_VALUE again if you wait here long enough
            }
            next = counter.incrementAndGet(); //here you increment a counter that may have gone past MAX_VALUE again in the meantime
        }
    }
    

    To make it work, you have to make sure that no decision is based on stale information: inside the lock, increment first, then either reset the counter or return the value you just obtained.

    public byte nextValue() {
        int next = counter.incrementAndGet();
    
        if (next > Byte.MAX_VALUE) {
            synchronized(counter) {
                next = counter.incrementAndGet();
                //if value is still larger than max byte value, we reset it
                if (next > Byte.MAX_VALUE) {
                    counter.set(INITIAL_VALUE + 1);
                    next = INITIAL_VALUE + 1;
                    resetCounter.incrementAndGet();
                    if (isSlow) slowDownAndLog(10, "resetting");
                } else {
                    if (isSlow) slowDownAndLog(1, "missed");
                }
            }
        }
    
        return (byte) next;
    }
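The losing interleaving from the first point can be replayed deterministically in a single thread, which makes the extra -128 hit visible without any sleeps. This is a simulation, not a real scheduler run: the class `RaceReplay` and the "threads" A, B and the others are hypothetical, and each step is one operation from the question's original nextValue(), executed in an order the scheduler is allowed to choose (for example, B waiting for the lock while other threads keep using the fast path).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class RaceReplay {

    static final int INITIAL_VALUE = Byte.MIN_VALUE - 1;

    // Body of the synchronized block in the question's nextValue().
    static int lockedPart(AtomicInteger counter) {
        int i = counter.get();
        if (i > Byte.MAX_VALUE) {
            counter.set(INITIAL_VALUE); // "resetting"
        }                               // otherwise: "missed"
        return counter.incrementAndGet();
    }

    // The question's nextValue(), executed without any concurrency.
    static byte nextValue(AtomicInteger counter) {
        int next = counter.incrementAndGet();
        if (next > Byte.MAX_VALUE) {
            next = lockedPart(counter);
        }
        return (byte) next;
    }

    static List<Byte> replay() {
        List<Byte> results = new ArrayList<>();
        AtomicInteger counter = new AtomicInteger(Byte.MAX_VALUE); // 127 was just handed out

        counter.incrementAndGet();               // A sees 128 and heads for the lock
        counter.incrementAndGet();               // B sees 129 and blocks on the lock
        results.add((byte) lockedPart(counter)); // A reads 129, resets, returns -128
        for (int k = 0; k < 255; k++) {          // other threads get -127 .. 127 on the fast path
            results.add(nextValue(counter));
        }
        results.add((byte) lockedPart(counter)); // B reads 127 ("missed"), increments to 128 -> -128
        results.add(nextValue(counter));         // next caller sees 129, resets, returns -128 again
        return results;
    }

    public static void main(String[] args) {
        List<Byte> results = replay();
        System.out.println("calls: " + results.size());
        System.out.println("-128 hits: " + Collections.frequency(results, (byte) -128));
        System.out.println("-127 hits: " + Collections.frequency(results, (byte) -127));
        // A correct generator would produce -128 twice and -127 twice in 258 calls.
    }
}
```

The replay yields -128 three times and -127 only once. In the corrected nextValue() above, B's incrementAndGet() happens inside the lock before the check, so B sees 128, resets, and returns -128 as the legitimate start of the next cycle.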