Search code examples
java concurrency java.util.concurrent concurrenthashmap

Best way to implement conditional thread barrier


The task looks like this. We have a bunch of threads that run some method. It's OK to run it concurrently, but only if certain conditions are fulfilled. If they are not, a thread should wait.

Here is an example of what I am talking about.

We have a time-consuming method doStuff which takes a Key instance as a parameter. It works correctly in a multithreaded environment only if the keys provided are not equal; it fails if two equal keys are processed simultaneously. We need to write code that stops threads with equal keys from calling this method at the same time. I have written three implementations: via a ConcurrentHashMap of these keys, via an AtomicIntegerArray of key indexes, and via a simple synchronized block which examines the set of keys currently being processed.

/**
 * Demonstrates three ways to let threads run {@link #doStuff(Key)} concurrently
 * as long as no two of them are processing equal keys at the same time.
 *
 * <p>Each {@code via...} method follows the same protocol: claim the key (blocking
 * while an equal key is claimed by another thread), run {@code doStuff}, then
 * release the key in a {@code finally} block and wake up all waiters.
 *
 * <p>All waiting is done on a shared monitor. A failed claim is always re-checked
 * while holding that monitor before calling {@code wait()}, so a release that
 * happens between the failed claim and the wait cannot be missed.
 */
public class KeyProblem {

    /** Number of distinct key indexes; also the size of the per-index flag array. */
    private static final int KEY_SPACE = 10;

    /**
     * A key identified solely by a random index in {@code [0, KEY_SPACE)}.
     * Two keys are equal exactly when their indexes are equal.
     */
    static class Key {

        private final int index;

        Key() {
            this.index = (int) (Math.random() * KEY_SPACE) % KEY_SPACE;
        }

        /** @return the index of this key, in {@code [0, KEY_SPACE)} */
        public int getIndex() {
            return index;
        }

        @Override
        public String toString() {
            return "Key{" +
                    "index=" + index +
                    '}';
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            return index == ((Key) o).index;
        }

        @Override
        public int hashCode() {
            return index;
        }
    }

    /** Keys currently inside {@link #doStuff(Key)}; used only to detect violations. */
    private static final ConcurrentHashMap<Key, Object> keysProcessedRightNowForCheck =
            new ConcurrentHashMap<Key, Object>();

    /**
     * Simulates the time-consuming operation that must never run concurrently
     * for two equal keys. Prints an error if it detects that an equal key is
     * already being processed.
     *
     * <p>If the calling thread is interrupted while sleeping, the method returns
     * early with the thread's interrupt flag set again, so callers can observe it.
     *
     * @param key the key to process
     */
    public static void doStuff(Key key) {
        final Object marker = new Object();
        if (keysProcessedRightNowForCheck.putIfAbsent(key, marker) != null) {
            System.out.println("ERROR! Equal keys! " + key + " " + Thread.currentThread().getName());
        }
        try {
            System.out.println(String.format("   started by %s with %s ", Thread.currentThread().getName(), key));
            Thread.sleep(500);
            System.out.println(String.format("   finished by %s with %s ", Thread.currentThread().getName(), key));
        } catch (InterruptedException e) {
            // Do not swallow the interruption: restore the flag for the caller.
            Thread.currentThread().interrupt();
        } finally {
            // Remove only our own marker; on a collision the entry belongs to
            // the other thread and must stay in place.
            keysProcessedRightNowForCheck.remove(key, marker);
        }
    }

    //first version: via ConcurrentHashMap

    private static final ConcurrentHashMap<Key, Object> map = new ConcurrentHashMap<Key, Object>();
    private static final Object waiter = new Object();

    /**
     * Claims the key by inserting it into a {@link ConcurrentHashMap}; waits on
     * {@code waiter} while an equal key is present.
     *
     * @param key the key to process
     * @throws InterruptedException if interrupted while waiting for the key
     */
    public static void viaConcurrentHashMap(Key key) throws InterruptedException {
        final Object marker = new Object();
        // Lock-free fast path; fall back to the monitor only on contention.
        if (map.putIfAbsent(key, marker) != null) {
            synchronized (waiter) {
                // Retrying under the monitor closes the lost-wakeup window: the
                // releasing thread needs this monitor to call notifyAll().
                while (map.putIfAbsent(key, marker) != null) {
                    System.out.println("wait with key " + key);
                    waiter.wait();
                    System.out.println("done waiting with key " + key);
                }
            }
        }
        System.out.println("started stuff with " + key);
        try {
            doStuff(key);
        } finally {
            map.remove(key);
            synchronized (waiter) {
                System.out.println("notified after stuff with " + key);
                waiter.notifyAll();
            }
        }
    }

    //second version: via AtomicIntegerArray for a fixed number of keys

    private static final AtomicIntegerArray keyProcessed = new AtomicIntegerArray(KEY_SPACE);

    /**
     * Claims the key by flipping its slot in an {@link AtomicIntegerArray} from
     * 0 to 1; waits on {@code waiter} while the slot is already 1.
     *
     * @param key the key to process; its index must be a valid array index
     * @throws InterruptedException if interrupted while waiting for the key
     */
    public static void viaAtomicIntegerArray(Key key) throws InterruptedException {
        final int slot = key.getIndex();

        // Lock-free fast path; fall back to the monitor only on contention.
        if (!keyProcessed.compareAndSet(slot, 0, 1)) {
            synchronized (waiter) {
                // Retry under the monitor so a concurrent release + notifyAll()
                // cannot slip in between the failed CAS and wait().
                while (!keyProcessed.compareAndSet(slot, 0, 1)) {
                    System.out.println("wait with key " + key);
                    waiter.wait();
                }
            }
        }

        try {
            doStuff(key);
        } finally {
            keyProcessed.set(slot, 0);
            synchronized (waiter) {
                System.out.println("notified after stuff with " + key);
                waiter.notifyAll();
            }
        }
    }

    //third version: via a simple lock

    private static final Object lock = new Object();
    private static final Set<Key> keys = new HashSet<Key>();

    /**
     * Claims the key by adding it to a plain {@link HashSet} guarded by
     * {@code lock}; waits on {@code lock} while an equal key is in the set.
     *
     * @param key the key to process
     * @throws InterruptedException if interrupted while waiting for the key
     */
    public static void viaSimpleSynchronized(Key key) throws InterruptedException {
        synchronized (lock) {
            while (keys.contains(key)) {
                lock.wait();
            }
            keys.add(key);
        }

        try {
            doStuff(key);
        } finally {
            synchronized (lock) {
                keys.remove(key);
                lock.notifyAll();
            }
        }
    }

    /**
     * Starts {@code MAX} threads, each processing one random key, waits for all
     * of them at a barrier and prints the elapsed wall-clock time.
     *
     * @param args ignored
     * @throws InterruptedException   if the main thread is interrupted at the barrier
     * @throws BrokenBarrierException if a worker left the barrier abnormally
     */
    public static void main(String[] args) throws InterruptedException, BrokenBarrierException {

        final int MAX = 100;

        final List<Key> workload = new ArrayList<Key>(MAX);
        for (int i = 0; i < MAX; i++) {
            workload.add(new Key());
        }

        // MAX workers plus the main thread meet here once all work is done.
        final CyclicBarrier barrier = new CyclicBarrier(MAX + 1);

        long start = System.currentTimeMillis();

        for (final Key key : workload) {
            Thread t = new Thread() {
                @Override
                public void run() {
                    try {
                        // Swap in viaConcurrentHashMap(key) or
                        // viaAtomicIntegerArray(key) to compare the variants.
                        viaSimpleSynchronized(key);
                        barrier.await();
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } catch (BrokenBarrierException e) {
                        e.printStackTrace();
                    }
                }
            };
            t.start();
        }

        barrier.await();
        System.out.println("system time [ms] " + (System.currentTimeMillis() - start));
        //7 for array
    }
}

For 100 threads the running time is around 7 s; the third version is slightly slower, as expected.

Finally, the questions are:

1) Is my code correct and a thread-safe one?

2) Can you suggest a better implementation?

3) Are there any classes in java.util.concurrent that solve this task in a generalized way? I mean a kind of barrier which lets a thread proceed only if some condition is fulfilled.


Solution

  • If you can afford to hold all the keys in memory at once, and as singletons (at least that's how your example works), then it seems that a very simple solution would be as follows:

    1. fetch the appropriate key;
    2. execute the logic inside a synchronized (key) {} block.