Search code examples
Tags: java, performance, distributed-computing

Load-balancing queues for variable-rate consumers


I have a producer/consumer framework. Producers push messages onto queues and consumers take messages from them. At any point in time there can be one or more queues; each consumer consumes from a single queue, but a producer can write to any queue. If a consumer is slow, its queue keeps piling up with messages. I am trying to come up with a framework that load-balances the consumers so that all queues hold roughly the same number of messages, regardless of each consumer's speed.

Example:

enter image description here

Here, queues Q1–Q3 are supposed to hold roughly equal numbers of messages irrespective of the rates of consumers C1–C3. The default policy I am currently using is round-robin distribution by the producers, but if any consumer is slow, messages keep piling up in its queue. All messages are of the same type, so any message can go to any of the queues.

Any suggestions on where to start would be helpful.


Solution

  • Below is the solution I implemented. The algorithm is as follows:

    1. Every 30 seconds, compute the mean size of all queues.
    2. If a queue's size exceeds that mean by at least a fixed threshold, stop sending new messages to that queue (and therefore to its consumer) until the next recomputation.

    Producer Code:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;
    import java.util.concurrent.BlockingQueue;
    import java.util.concurrent.ThreadLocalRandom;
    
    /**
     * Producer that distributes randomly generated integers across a set of queues.
     *
     * <p>Messages are assigned round-robin over the currently "active" partitions. Once per
     * rebalance interval the producer recomputes the mean queue size over {@code fullPartitions}.
     * It then excludes every partition whose queue exceeds that mean by {@code THRESHOLD} messages
     * or more, so a slow consumer stops receiving new work until it catches up.
     *
     * <p>The {@code activePartitions} list passed to the constructor is only read. Rebalancing
     * replaces the reference with a fresh list instead of mutating the (possibly shared) original.
     * The class has no internal synchronization, so an instance should be driven by a single
     * thread.
     */
    public class Producer implements Runnable {

        /** Default lag (in messages) above the mean at which a partition is excluded. */
        private static final int DEFAULT_THRESHOLD = 10000;
        /** Default time between rebalancing passes. */
        private static final long DEFAULT_REBALANCE_INTERVAL_MILLIS = 30000L;
        /** Default total running time of {@link #run()}. */
        private static final long DEFAULT_RUN_DURATION_MILLIS = 300000L;
        /** Exclusive upper bound for the random payload values. */
        private static final int MAX_MESSAGE_VALUE = 100000;

        private final List<BlockingQueue<Integer>> blockingQueues;
        private final List<Integer> fullPartitions;
        private List<Integer> activePartitions;
        private final long rebalanceIntervalMillis;
        private final long runDurationMillis;

        // The three fields below keep their original names and package-private visibility
        // for compatibility with any same-package code that may touch them.
        /** Time of the last rebalance, in epoch milliseconds. */
        long timer = System.currentTimeMillis();
        /** Lag above the mean (in messages) at which a partition is excluded. */
        int THRESHOLD;
        /** Round-robin cursor into the active partition list; always kept non-negative. */
        int currentQueue = 0;

        /**
         * Creates a producer with the default threshold (10000 messages), rebalance interval
         * (30 s) and run duration (300 s).
         *
         * @param blockingQueues   queues indexed by partition ID
         * @param fullPartitions   every partition ID this producer may write to
         * @param activePartitions partitions to use until the first rebalance; only read
         */
        public Producer(List<BlockingQueue<Integer>> blockingQueues, List<Integer> fullPartitions,
                List<Integer> activePartitions) {
            this(blockingQueues, fullPartitions, activePartitions, DEFAULT_THRESHOLD,
                    DEFAULT_REBALANCE_INTERVAL_MILLIS, DEFAULT_RUN_DURATION_MILLIS);
        }

        /**
         * Creates a producer with explicit tuning parameters.
         *
         * @param blockingQueues          queues indexed by partition ID
         * @param fullPartitions          every partition ID this producer may write to
         * @param activePartitions        partitions to use until the first rebalance; only read
         * @param threshold               lag above the mean (in messages) that excludes a partition
         * @param rebalanceIntervalMillis time between rebalancing passes
         * @param runDurationMillis       how long {@link #run()} keeps producing
         */
        public Producer(List<BlockingQueue<Integer>> blockingQueues, List<Integer> fullPartitions,
                List<Integer> activePartitions, int threshold, long rebalanceIntervalMillis,
                long runDurationMillis) {
            this.blockingQueues = blockingQueues;
            this.fullPartitions = fullPartitions;
            this.activePartitions = activePartitions;
            this.THRESHOLD = threshold;
            this.rebalanceIntervalMillis = rebalanceIntervalMillis;
            this.runDurationMillis = runDurationMillis;
        }

        /**
         * Produces roughly one message per millisecond until the run duration has elapsed or the
         * thread is interrupted. On interruption the interrupt flag is restored and the method
         * returns.
         */
        @Override
        public void run() {
            long start = System.currentTimeMillis();
            try {
                while (true) {
                    int message = ThreadLocalRandom.current().nextInt(MAX_MESSAGE_VALUE);
                    // put() waits for space rather than silently dropping the message when the
                    // chosen queue is full (which offer() did).
                    blockingQueues.get(getNextID()).put(message);
                    if (System.currentTimeMillis() - start >= runDurationMillis) {
                        break;
                    }
                    Thread.sleep(1);
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt so whoever owns this thread can observe it, then stop.
                Thread.currentThread().interrupt();
            }
        }

        /**
         * Returns the partition that should receive the next message. If the rebalance interval
         * has elapsed, the partitions are rebalanced first. Package-private for testing.
         *
         * @throws IllegalStateException if there are no partitions at all
         */
        int getNextID() {
            if (System.currentTimeMillis() - timer > rebalanceIntervalMillis) {
                rebalance();
            }
            // Fall back to all partitions if every one was excluded (possible with threshold <= 0).
            List<Integer> candidates = activePartitions.isEmpty() ? fullPartitions : activePartitions;
            if (candidates.isEmpty()) {
                throw new IllegalStateException("No partitions available to produce to");
            }
            // floorMod keeps the index non-negative. Storing index + 1 instead of an ever-growing
            // counter means the cursor can never overflow into negative values.
            int index = Math.floorMod(currentQueue, candidates.size());
            currentQueue = index + 1;
            return candidates.get(index);
        }

        /**
         * Recomputes the active partitions and resets the rebalance timer. A partition in
         * {@code fullPartitions} stays active when its queue size is less than
         * {@code mean + THRESHOLD}, where the mean is taken over all of {@code fullPartitions}.
         * Package-private for testing.
         */
        void rebalance() {
            List<Integer> refreshed = new ArrayList<>();
            if (!fullPartitions.isEmpty()) {
                long total = 0L;
                for (int partition : fullPartitions) {
                    total += blockingQueues.get(partition).size();
                }
                long mean = total / fullPartitions.size();
                for (int partition : fullPartitions) {
                    if (blockingQueues.get(partition).size() - mean < THRESHOLD) {
                        refreshed.add(partition);
                    }
                }
            }
            activePartitions = refreshed;
            timer = System.currentTimeMillis();
        }
    }
    

    Consumer:

    import java.util.concurrent.ArrayBlockingQueue;
    import java.util.concurrent.BlockingQueue;
    
    /**
     * Simulated consumer that drains a single queue. After each message it pauses
     * {@code delayFactor} milliseconds, to model a slow downstream.
     *
     * <p>Whenever a take leaves the queue empty, the number of whole seconds since {@link #run()}
     * started is printed to stdout. The loop runs until the thread is interrupted.
     */
    public class Consumer implements Runnable {

        // Assigned only in the constructor. The original field initializer eagerly allocated
        // (and immediately discarded) an ArrayBlockingQueue with a 100,000,000-slot backing array.
        private final BlockingQueue<Integer> blockingQueue;
        private final int delayFactor;

        /**
         * @param blockingQueue queue to consume from
         * @param delayFactor   milliseconds to sleep after each message
         * @param consumerNo    identifier of this consumer; currently unused
         */
        public Consumer(BlockingQueue<Integer> blockingQueue, int delayFactor, int consumerNo) {
            this.blockingQueue = blockingQueue;
            this.delayFactor = delayFactor;
        }

        /**
         * Takes messages until interrupted. On interruption the interrupt flag is restored and the
         * method returns, rather than swallowing the exception and looping forever.
         */
        @Override
        public void run() {
            long start = System.currentTimeMillis();
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    blockingQueue.take();
                    if (blockingQueue.isEmpty()) {
                        System.out.println((System.currentTimeMillis() - start) / 1000);
                    }
                    Thread.sleep(delayFactor);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }
    

    Main Thread:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ArrayBlockingQueue;
    import java.util.concurrent.BlockingQueue;
    
    /**
     * Demo harness for the load balancer.
     *
     * <p>It creates {@code MAX_PARTITION} bounded queues and starts the same number of producers,
     * all sharing every queue. It then starts one consumer per queue (consumer {@code i} sleeps
     * {@code i + 1} ms per message) and a thread that periodically prints the queue sizes.
     */
    public class KafkaLoadBalancer {

        /** Number of queues, producers and consumers. */
        private static final int MAX_PARTITION = 4;

        public static void main(String[] args) throws InterruptedException {
            List<BlockingQueue<Integer>> blockingQueues = new ArrayList<>(MAX_PARTITION);
            List<Integer> fullPartitions = new ArrayList<>(MAX_PARTITION);
            List<Integer> activePartitions = new ArrayList<>(MAX_PARTITION);

            System.out.println("Creating Queues");
            for (int i = 0; i < MAX_PARTITION; i++) {
                blockingQueues.add(new ArrayBlockingQueue<>(1000000));
                fullPartitions.add(i);
                activePartitions.add(i);
            }

            System.out.println("Starting Producers");
            for (int i = 0; i < MAX_PARTITION; i++) {
                Producer producer = new Producer(blockingQueues, fullPartitions, activePartitions);
                // Named threads make thread dumps and profiler output readable.
                new Thread(producer, "producer-" + i).start();
            }

            System.out.println("Starting Consumers");
            for (int i = 0; i < MAX_PARTITION; i++) {
                Consumer consumer = new Consumer(blockingQueues.get(i), i + 1, i);
                new Thread(consumer, "consumer-" + i).start();
            }

            System.out.println("Starting Display Thread");
            DisplayQueue dq = new DisplayQueue(blockingQueues);
            new Thread(dq, "display-queue").start();
        }
    }
    

    DisplayQueue: prints the size of each queue

    import java.util.List;
    import java.util.concurrent.BlockingQueue;
    
    /**
     * Periodically prints the size of every queue so the effect of load balancing can be watched.
     *
     * <p>Sleeps between reports instead of busy-polling the clock. The original busy-wait burned
     * a full CPU core, which also distorts the producer/consumer timings being observed. The
     * reporter runs until its thread is interrupted.
     */
    public class DisplayQueue implements Runnable {

        /** Default time between reports (30 s, as in the original). */
        private static final long DEFAULT_INTERVAL_MILLIS = 30000L;

        private final List<BlockingQueue<Integer>> blockingQueues;
        private final long intervalMillis;

        /** Creates a reporter that prints every 30 seconds. */
        public DisplayQueue(List<BlockingQueue<Integer>> blockingQueues) {
            this(blockingQueues, DEFAULT_INTERVAL_MILLIS);
        }

        /**
         * @param blockingQueues queues to report on; queue {@code i} is printed as "Queue i"
         * @param intervalMillis time between reports
         * @throws IllegalArgumentException if {@code intervalMillis} is negative
         */
        public DisplayQueue(List<BlockingQueue<Integer>> blockingQueues, long intervalMillis) {
            if (intervalMillis < 0) {
                throw new IllegalArgumentException("intervalMillis must be >= 0: " + intervalMillis);
            }
            this.blockingQueues = blockingQueues;
            this.intervalMillis = intervalMillis;
        }

        @Override
        public void run() {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    Thread.sleep(intervalMillis);
                    for (int i = 0; i < blockingQueues.size(); i++) {
                        System.out.println("Queue " + i + " size is==" + blockingQueues.get(i).size());
                    }
                }
            } catch (InterruptedException e) {
                // Preserve the interrupt for the thread's owner and stop reporting.
                Thread.currentThread().interrupt();
            }
        }
    }