Search code examples
Tags: java, multithreading, executorservice, producer-consumer, blockingqueue

Kill consumers when the BlockingQueue is empty


I'm reading up on BlockingQueue, ExecutorService and the producer-consumer paradigm. I want to have a changing number of producers and a changing number of consumers. Each producer will append to the queue, and the consumers will consume the messages and process them. The question I have is: how do the consumers know that the producers are done, and that no more messages will enter the queue? I thought of adding a counter to my main thread: when a producer is started, I increment the counter, and when each producer ends, it decrements the counter. My consumers will be able to read the counter, and when it reaches 0 and there are no more elements in the queue, they can die.

Another general question about coordinating the work: should the main thread read the contents of the queue and submit a task to an executor for each message, or is it better practice to have the threads contain this logic and decide on their own when to die?

When the system starts up, I receive a number that decides how many producers will start. Each producer will generate a random set of numbers into the queue. The consumers will print these numbers to a log. The issue I'm having is that, once I know the last producer has pushed the last number in, I still don't understand how to let the consumers know that no more numbers will come in and that they should shut down.

How do the consumers know when the producers are done?


Solution

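  • One elegant solution to this problem is to use the Poison Pill pattern. When the last producer finishes, it puts a special sentinel message (the "poison pill") on the queue. Each consumer that takes the pill puts it back for the next consumer and then exits. Here is an example of how it works. All you need to know in this case is the number of producers.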

    Edit: I updated the code to clear the queue when the last consumer finishes its work.

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.BlockingQueue;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.LinkedBlockingQueue;
    import java.util.concurrent.atomic.AtomicInteger;
    
    /**
     * Demonstrates the poison-pill shutdown pattern for a producer/consumer pipeline that shares a
     * single {@link BlockingQueue}.
     *
     * <p>Every producer decrements a shared producer counter when it finishes, whether it finishes
     * normally or not. The producer that brings the counter to zero enqueues a single
     * {@link PoisonPill}. A consumer that takes the pill puts it back for the next consumer and
     * exits. The last consumer to exit clears the queue.
     */
    public class PoisonPillsTests {

        /** Marker type for everything that travels through the queue. */
        interface Message {

        }

        /** Sentinel message telling consumers that no further messages will arrive. */
        interface PoisonPill extends Message {
            PoisonPill INSTANCE = new PoisonPill() {
            };
        }

        /** A message carrying a text payload; {@link #toString()} returns the payload. */
        static class TextMessage implements Message {

            private final String text;

            public TextMessage(String text) {
                this.text = text;
            }

            public String getText() {
                return text;
            }

            @Override
            public String toString() {
                return text;
            }
        }

        /**
         * Puts {@code message} on {@code queue}, retrying if the calling thread is interrupted, and
         * restores the thread's interrupt status before returning.
         *
         * <p>Used for the poison pill. The pill must reach the queue even while a thread is being
         * interrupted; otherwise consumers still waiting in {@code take()} would block forever.
         */
        private static void putUninterruptibly(BlockingQueue<Message> queue, Message message) {
            // Clear the flag first so put() does not fail immediately; it is restored in finally.
            boolean interrupted = Thread.interrupted();
            try {
                while (true) {
                    try {
                        queue.put(message);
                        return;
                    } catch (InterruptedException ignored) {
                        // Remember the interrupt and retry; the status is restored below.
                        interrupted = true;
                    }
                }
            } finally {
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }

        /**
         * Puts a fixed number of {@link TextMessage}s on the shared queue. The producer that brings
         * {@code producersCount} to zero also enqueues the single {@link PoisonPill}.
         */
        static class Producer implements Runnable {

            /** Number of text messages each producer enqueues. */
            private static final int MESSAGES_PER_PRODUCER = 100;

            private final String producerName;
            private final AtomicInteger producersCount;
            private final BlockingQueue<Message> messageBlockingQueue;

            public Producer(String producerName, BlockingQueue<Message> messageBlockingQueue, AtomicInteger producersCount) {
                this.producerName = producerName;
                this.messageBlockingQueue = messageBlockingQueue;
                this.producersCount = producersCount;
            }

            /**
             * Enqueues this producer's messages.
             *
             * @throws RuntimeException wrapping the {@link InterruptedException} if interrupted while
             *     enqueuing; the thread's interrupt status is restored first
             */
            @Override
            public void run() {
                try {
                    for (int i = 0; i < MESSAGES_PER_PRODUCER; i++) {
                        messageBlockingQueue.put(new TextMessage("Producer " + producerName + " message " + i));
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Producer interrupted", e);
                } finally {
                    // Count down even when this producer fails. Otherwise the pill is never sent and
                    // every consumer blocks in take() forever. Only the last producer sends the pill.
                    if (producersCount.decrementAndGet() <= 0) {
                        putUninterruptibly(messageBlockingQueue, PoisonPill.INSTANCE);
                    }
                }
            }
        }

        /**
         * Takes messages from the shared queue and prints them until it sees the {@link PoisonPill}.
         * The last consumer to stop clears whatever is left in the queue (i.e. the pill).
         */
        static class Consumer implements Runnable {

            private final AtomicInteger consumersCount;
            private final AtomicInteger consumedMessages;
            private final BlockingQueue<Message> messageBlockingQueue;

            public Consumer(BlockingQueue<Message> messageBlockingQueue, AtomicInteger consumersCount, AtomicInteger consumedMessages) {
                this.messageBlockingQueue = messageBlockingQueue;
                this.consumersCount = consumersCount;
                this.consumedMessages = consumedMessages;
            }

            /**
             * Consumes messages until the poison pill arrives.
             *
             * @throws RuntimeException wrapping the {@link InterruptedException} if interrupted while
             *     waiting; the thread's interrupt status is restored first
             */
            @Override
            public void run() {
                try {
                    while (true) {
                        Message message = messageBlockingQueue.take();

                        if (message instanceof PoisonPill) {
                            // Put the pill back so the next consumer also sees it and stops; this
                            // must not be lost to an interrupt.
                            putUninterruptibly(messageBlockingQueue, message);
                            break;
                        }
                        consumedMessages.incrementAndGet();
                        System.out.println("Consumer got message " + message);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("Consumer interrupted", e);
                } finally {
                    if (consumersCount.decrementAndGet() <= 0) {
                        System.out.println("Last consumer, clearing the queue");
                        messageBlockingQueue.clear();
                    }
                }
            }
        }

        /**
         * Runs 4 producers and 2 consumers over one unbounded queue, waits for all of them, and prints
         * the number of consumed messages.
         */
        public static void main(String[] args) {
            final int numProducers = 4;
            final int numConsumers = 2;
            final AtomicInteger producerCount = new AtomicInteger(numProducers);
            final AtomicInteger consumersCount = new AtomicInteger(numConsumers);
            final AtomicInteger consumedMessages = new AtomicInteger();
            BlockingQueue<Message> messageBlockingQueue = new LinkedBlockingQueue<>();

            // The tasks block on the queue, so give them their own threads instead of occupying
            // ForkJoinPool.commonPool(); one thread per task lets all of them run concurrently.
            ExecutorService executor = Executors.newFixedThreadPool(numProducers + numConsumers);
            try {
                List<CompletableFuture<Void>> tasks = new ArrayList<>();
                // Loop bounds are the fixed task counts, NOT the shared counters: running tasks
                // decrement those concurrently, which could stop the loops early and leave the
                // producer counter above zero forever (no pill -> consumers hang).
                for (int i = 0; i < numProducers; i++) {
                    tasks.add(CompletableFuture.runAsync(
                            new Producer(String.valueOf(i + 1), messageBlockingQueue, producerCount), executor));
                }

                for (int i = 0; i < numConsumers; i++) {
                    tasks.add(CompletableFuture.runAsync(
                            new Consumer(messageBlockingQueue, consumersCount, consumedMessages), executor));
                }

                CompletableFuture.allOf(tasks.toArray(new CompletableFuture<?>[0])).join();
            } finally {
                executor.shutdown();
            }

            System.out.println("Consumed " + consumedMessages + " messages");

        }
    }