Search code examples
Tags: java, multithreading, concurrency, thread-safety, volatile

Volatile doesn't work as expected


I'm reading Brian Goetz's *Java Concurrency in Practice* (JCIP) and wrote the following code to experiment with volatile behavior.

/**
 * Experiment from the question: worker threads repeatedly increment whichever counter
 * {@code state.counter} currently references while the main thread swaps that reference from
 * {@code oldCounter} to {@code newCounter}; two maps record when each counter was observed.
 */
public class StatefulObject {

    private static final int NUMBER_OF_THREADS = 10;

    private volatile State state;

    public StatefulObject() {
        state = new State();
    }

    public State getState() {
        return state;
    }

    public void setState(State state) {
        this.state = state;
    }

    /** Holds the counter reference that main() swaps while worker tasks are running. */
    public static class State {
        // volatile because setCounter() replaces this reference after the worker tasks have been
        // submitted; the volatility of AtomicInteger's own value does not cover this field.
        private volatile AtomicInteger counter;

        public State() {
            counter = new AtomicInteger();
        }

        public AtomicInteger getCounter() {
            return counter;
        }

        public void setCounter(AtomicInteger counter) {
            this.counter = counter;
        }
    }

    /**
     * Starts NUMBER_OF_THREADS tasks that each perform 1000 increments, swaps the counter while
     * they run, waits for them, and prints when the old counter was last seen and the new counter
     * first seen.
     *
     * @param args ignored
     * @throws InterruptedException declared, although interruption while waiting on the futures
     *     is caught and printed below
     */
    public static void main(String[] args) throws InterruptedException {
        StatefulObject object = new StatefulObject();

        ExecutorService executorService = Executors.newFixedThreadPool(NUMBER_OF_THREADS);

        AtomicInteger oldCounter = new AtomicInteger();
        AtomicInteger newCounter = new AtomicInteger();

        object.getState().setCounter(oldCounter);

        // Keys are AtomicInteger.hashCode() values. lastSeen only tracks the old counter (starting
        // at 0) and firstSeen only the new one (starting at Long.MAX_VALUE); computeIfPresent with
        // any other key leaves the map unchanged.
        ConcurrentMap<Integer, Long> lastSeen = new ConcurrentHashMap<>();
        ConcurrentMap<Integer, Long> firstSeen = new ConcurrentHashMap<>();
        lastSeen.put(oldCounter.hashCode(), 0L);
        firstSeen.put(newCounter.hashCode(), Long.MAX_VALUE);

        List<Future> futures = IntStream.range(0, NUMBER_OF_THREADS)
            .mapToObj(num -> executorService.submit(() -> {
                for (int i = 0; i < 1000; i++) {
                    // NOTE(review): each object.getState().getCounter() below is a separate volatile
                    // read, so the counter incremented here and the keys used by the two map
                    // updates can differ if the swap happens mid-iteration.
                    // System.nanoTime() is evaluated inside the remapping function, i.e. after the
                    // key was read (and possibly after waiting for other threads updating the same
                    // key), so a thread that read the old counter may record a later timestamp than
                    // a thread that already read the new one. This measurement gap alone can
                    // produce "Old was seen after the new: true".
                    object.getState().getCounter().incrementAndGet();
                    lastSeen.computeIfPresent(object.getState().getCounter().hashCode(), (key, oldValue) -> Math.max(oldValue, System.nanoTime()));
                    firstSeen.computeIfPresent(object.getState().getCounter().hashCode(), (key, oldValue) -> Math.min(oldValue, System.nanoTime()));
                }
            })).collect(Collectors.toList());

        // Stops accepting new tasks; the tasks submitted above keep running.
        executorService.shutdown();

        // Swap the counter while the worker tasks may still be running.
        object.getState().setCounter(newCounter);

        futures.forEach(future -> {
            try {
                future.get();
            } catch (InterruptedException e) {
                // NOTE(review): the interruption is only printed; the interrupt status is not
                // restored and this future's task may not have finished yet.
                e.printStackTrace();
            } catch (ExecutionException e) {
                e.printStackTrace();
            }
        });

        // After the swap the state references newCounter, so this prints the new counter's value.
        System.out.printf("Counter: %s\n", object.getState().getCounter().get());
        // If the old counter was never observed lastSeenOld stays 0; if the new counter was never
        // observed firstSeenNew stays Long.MAX_VALUE.
        long lastSeenOld = lastSeen.get(oldCounter.hashCode());
        long firstSeenNew = firstSeen.get(newCounter.hashCode());
        System.out.printf("Last seen old counter: %s\n", lastSeenOld);
        System.out.printf("First seen new counter: %s\n", firstSeenNew);
        System.out.printf("Old was seen after the new: %s\n", lastSeenOld > firstSeenNew);
        System.out.printf("Old was seen %s nanoseconds after the new\n", lastSeenOld - firstSeenNew);
    }
}

I expect newCounter to be first seen only after oldCounter was last seen: I expect all threads to notice the update, so none of them keeps referencing the stale counter. To observe this, I use two maps. Surprisingly, though, I consistently get output like this:

Counter: 9917
Last seen old counter: 695372684800871
First seen new counter: 695372684441226
Old was seen after the new: true
Old was seen 359645 nanoseconds after the new

Can you please explain where I'm wrong?

Thanks in advance!


Solution

  • The reason behind your observation is not a bug in Java ;) but one in your measurement code. `ConcurrentHashMap.computeIfPresent` itself is performed atomically (see its Javadoc). What is not atomic is the whole sequence: reading `object.getState().getCounter()` to build the key, entering `computeIfPresent`, and calling `System.nanoTime()` inside the remapping function. This means there is a time gap between the moment you get `object.getState().getCounter()` and the moment the timestamp is actually taken and stored in the map.

    Suppose newCounter is set while thread A is inside this gap: it has already read the (old) counter reference but has not yet called `nanoTime`. At the same moment, thread B is just about to call `object.getState().getCounter()`. Thread A will then update the old counter's key, while thread B updates the new one. If thread B calls `nanoTime` before thread A does, you get exactly the output you observed. That can happen because they are separate threads and we cannot know how the CPU schedules them.

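    I hope the explanation is clear. One more thing to clarify: in the State class, you have also declared the `AtomicInteger counter` field as volatile. The counter's value is inherently volatile, since there are no "non-volatile" Atomic* classes, so increments are visible to all threads without it. However, the `volatile` on the field concerns the reference, not the value. Because `setCounter` replaces that reference while worker threads are running, the field must stay volatile (or be otherwise synchronized) for those threads to be guaranteed to see the new counter.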

    I changed a few things in your code to avoid the issues mentioned above:

    import java.util.Collections;
    import java.util.List;
    import java.util.concurrent.*;
    import java.util.concurrent.atomic.AtomicInteger;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;
    
    /**
     * Experiment showing that, once the volatile {@code counter} reference has been swapped, worker
     * threads stop using the old counter: every read that returned the old counter should be
     * timestamped no later than any read that returned the new one.
     */
    public class StatefulObject {
    
        private static final int NUMBER_OF_THREADS = 10;
    
        /** Number of increments each worker task performs. */
        private static final int ITERATIONS_PER_THREAD = 1000;
    
        private volatile State state;
    
        /** Creates an object whose state holds a fresh counter starting at zero. */
        public StatefulObject() {
            state = new State();
        }
    
        /** Returns the current state (a volatile read of {@code state}). */
        public State getState() {
            return state;
        }
    
        /** Replaces the current state (a volatile write of {@code state}). */
        public void setState(State state) {
            this.state = state;
        }
    
        /** Holds a counter reference that may be replaced while other threads are using it. */
        public static class State {
            // volatile is required: setCounter() swaps the reference while workers are running, and
            // the volatility of AtomicInteger's internal value does not extend to this field.
            private volatile AtomicInteger counter;
    
            /** Creates a state with a fresh counter starting at zero. */
            public State() {
                counter = new AtomicInteger();
            }
    
            /** Returns the counter currently referenced by this state. */
            public AtomicInteger getCounter() {
                return counter;
            }
    
            /** Points this state at {@code counter} (a volatile write). */
            public void setCounter(AtomicInteger counter) {
                this.counter = counter;
            }
        }
    
        /**
         * Runs the experiment: workers increment whichever counter is current while the main thread
         * swaps the old counter for a new one, then prints the latest time the old counter was seen
         * and the earliest time the new one was seen.
         *
         * @param args ignored
         * @throws InterruptedException if interrupted while waiting for the worker tasks
         * @throws IllegalStateException if a worker task fails
         */
        public static void main(String[] args) throws InterruptedException {
            StatefulObject object = new StatefulObject();
    
            ExecutorService executorService = Executors.newFixedThreadPool(NUMBER_OF_THREADS);
    
            AtomicInteger oldCounter = new AtomicInteger();
            AtomicInteger newCounter = new AtomicInteger();
    
            object.getState().setCounter(oldCounter);
    
            // Every read that returns oldCounter precedes the volatile write below, and every read
            // that returns newCounter follows it. Keeping the timestamp taken just BEFORE an old read
            // and just AFTER a new read gives beforeOld <= readOld < write < readNew <= afterNew, so
            // scheduling delays between a read and its timestamp can no longer invert the result
            // (assuming System.nanoTime() is consistent across threads).
            ConcurrentLinkedQueue<Long> oldTimestamps = new ConcurrentLinkedQueue<>();
            ConcurrentLinkedQueue<Long> newTimestamps = new ConcurrentLinkedQueue<>();
    
            List<Future<?>> futures = IntStream.range(0, NUMBER_OF_THREADS)
                .<Future<?>>mapToObj(ignored -> executorService.submit(() -> {
                    for (int i = 0; i < ITERATIONS_PER_THREAD; i++) {
                        long before = System.nanoTime();
                        // Read the volatile reference exactly once, so the counter that is
                        // incremented is the same one that gets classified as old or new.
                        AtomicInteger seen = object.getState().getCounter();
                        long after = System.nanoTime();
                        seen.incrementAndGet();
                        if (seen == oldCounter) {
                            oldTimestamps.add(before);
                        } else {
                            newTimestamps.add(after);
                        }
                    }
                })).collect(Collectors.toList());
    
            // Stops accepting new tasks; the tasks submitted above keep running.
            executorService.shutdown();
    
            object.getState().setCounter(newCounter);
    
            for (Future<?> future : futures) {
                try {
                    future.get();
                } catch (ExecutionException e) {
                    // A failed worker makes the measurement meaningless, so fail loudly.
                    throw new IllegalStateException("Worker task failed", e.getCause());
                }
            }
    
            System.out.printf("Counter: %s\n", object.getState().getCounter().get());
            if (oldTimestamps.isEmpty() || newTimestamps.isEmpty()) {
                // The swap happened before the first read or after the last one, so there is no
                // old/new boundary to compare.
                System.out.printf("Old counter reads: %s, new counter reads: %s - nothing to compare\n",
                    oldTimestamps.size(), newTimestamps.size());
                return;
            }
            long lastSeenOld = Collections.max(oldTimestamps);
            long firstSeenNew = Collections.min(newTimestamps);
            System.out.printf("Last seen old counter: %s\n", lastSeenOld);
            System.out.printf("First seen new counter: %s\n", firstSeenNew);
            System.out.printf("Old was seen after the new: %s\n", lastSeenOld > firstSeenNew);
            System.out.printf("Old was seen %s nanoseconds after the new\n", lastSeenOld - firstSeenNew);
        }
    }