Search code examples
java, nio, disruptor-pattern

How to make a Java NIO (Non blocking IO) based TCP server using Disruptor?


I'm trying to implement a JAVA NIO based TCP server using Disruptor.

Java NIO works in a non-blocking fashion. All new connections first hit the server's accept socket. After selector.select() returns, the ready keys are obtained from selector.selectedKeys(), and the appropriate handler is called for each key: if the key is acceptable, a new socket channel is created and registered with the selector; if the key is readable, the content is read from the channel, which is then registered for writing; and if the key is writable, the response is written to the channel. The simplest NIO-based server works in a single thread (all handlers and the selector run in the same thread).

The Java Disruptor is a high-performance ring buffer implementation that can be used to pass messages between different components (threads).

My questions are as follows.

  1. Can we use multiple threads for NIO design?

  2. Can we run the eventHandlers in separate threads?

  3. If we can run the eventHandlers in separate threads, how can we pass the selectionKeys and channels between threads?

  4. Can java Disruptor library be used for transferring data between main thread (in which selector runs) and eventHandler threads?

  5. If it is possible, what is the design approach? (What are the behaviours of EventProducer, EventConsumer and RingBuffer in Disruptor?)


Solution

  • You can build an NIO-based server using any inter-thread message-passing mechanism; the Disruptor is one such option.

    The problem you need to address is how to hand the work off to a different thread (rather than processing the request in the main thread itself).

    Therefore, you can pass the buffer you get from the socket connection to a separate thread using a disruptor as the messaging method. Also, you need to maintain a shared concurrent hashmap to inform the main thread (that runs the event loop) whether the response is ready or not. Following is an example.

    HttpEvent.java

    import java.nio.ByteBuffer;
    
    /**
     * Mutable value holder describing a single request: the buffer holding the
     * bytes, the number of bytes that were read into it, and the identifier the
     * request is tracked under.
     *
     * <p>All properties start out unset ({@code null} / {@code 0}) and are
     * populated through the setters.
     */
    public class HttpEvent {

        /** Identifier the request is tracked under. */
        private String requestId;

        /** Number of bytes that were read into {@link #buffer}. */
        private int numRead;

        /** Buffer carrying the request bytes. */
        private ByteBuffer buffer;

        /** @return the identifier of the request, or {@code null} if not set */
        public String getRequestId() {
            return this.requestId;
        }

        /** @param requestId the identifier of the request */
        public void setRequestId(String requestId) {
            this.requestId = requestId;
        }

        /** @return the number of bytes read, {@code 0} if not set */
        public int getNumRead() {
            return this.numRead;
        }

        /** @param numRead the number of bytes read into the buffer */
        public void setNumRead(int numRead) {
            this.numRead = numRead;
        }

        /** @return the buffer carrying the request bytes, or {@code null} if not set */
        public ByteBuffer getBuffer() {
            return this.buffer;
        }

        /** @param buffer the buffer carrying the request bytes */
        public void setBuffer(ByteBuffer buffer) {
            this.buffer = buffer;
        }
    }
    

    HttpEventFactory.java

    import com.lmax.disruptor.EventFactory;
    
    public class HttpEventFactory implements EventFactory<HttpEvent>
    {
        public HttpEvent newInstance()
        {
            return new HttpEvent();
        }
    }
    

    HttpEventHandler.java

    import com.lmax.disruptor.EventHandler;
    
    import java.nio.ByteBuffer;
    import java.util.Dictionary;
    import java.util.concurrent.ConcurrentHashMap;
    
    public class HttpEventHandler implements EventHandler<HttpEvent>
    {
        private int id;
        private ConcurrentHashMap concurrentHashMap;
    
        public HttpEventHandler(int id, ConcurrentHashMap concurrentHashMap){
            this.id = id;
            this.concurrentHashMap = concurrentHashMap;
    
        }
    
        public void onEvent(HttpEvent event, long sequence, boolean endOfBatch) throws Exception
        {
            if( sequence % Runtime.getRuntime().availableProcessors()==id){
    
    
                String requestId = event.getRequestId();
                ByteBuffer buffer = event.getBuffer();
                int numRead= event.getNumRead();
    
                ByteBuffer responseBuffer = handleRequest(buffer, numRead);
    
    
                this.concurrentHashMap.put(requestId, responseBuffer);
    
            }
        }
    
        private ByteBuffer handleRequest(ByteBuffer buffer, int numRead) throws Exception {
    
            buffer.flip();
            byte[] data = new byte[numRead];
            System.arraycopy(buffer.array(), 0, data, 0, numRead);
            String request = new String(data, "US-ASCII");
            request = request.split("\n")[0].trim();
    
    
            String response = serverRequest(request);
    
            buffer.clear();
    
            buffer.put(response.getBytes());
            return  buffer;
        }
    
        private String serverRequest(String request) throws Exception {
            String response = "Sample Response";
            if (request.startsWith("GET")) {
    
                // http request parsing and response generation should be done here.    
    
    
            return  response;
        }
    }
    

    HttpEventMain.java

    import com.lmax.disruptor.RingBuffer;
    import com.lmax.disruptor.dsl.Disruptor;
    import org.apache.commons.lang3.RandomStringUtils;
    
    import java.io.IOException;
    import java.net.*;
    import java.nio.ByteBuffer;
    import java.nio.channels.SelectionKey;
    import java.nio.channels.Selector;
    import java.nio.channels.ServerSocketChannel;
    import java.nio.channels.SocketChannel;
    import java.util.Iterator;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.Executor;
    import java.util.concurrent.Executors;
    
    /**
     * Single-threaded NIO selector loop that accepts connections, reads one
     * request per connection, hands the request buffer to an
     * {@link HttpEventProducer}, and writes the response back once a handler has
     * stored it in the shared response map.
     *
     * <p>Two maps coordinate the hand-off:
     * <ul>
     *   <li>{@code concurrentHashMapKey}: selection key -> request ID of the
     *       request that is currently in flight on that connection;</li>
     *   <li>{@code concurrentHashMapResponse}: request ID -> either the pending
     *       marker {@code "0"} or the finished response.</li>
     * </ul>
     *
     * <p>NOTE(review): while a response is pending, the connection's key stays
     * registered for {@code OP_WRITE}, so the select loop will presumably keep
     * waking up immediately and poll the map. Consider having the handler wake
     * the selector instead.
     */
    public class HttpEventMain
    {
        /** Value stored in the response map while a request is still being processed. */
        private static final String PENDING_MARKER = "0";

        private InetAddress addr;
        private int port;
        private Selector selector;
        private HttpEventProducer producer ;
        private ConcurrentHashMap<String, Object> concurrentHashMapResponse;
        private final ConcurrentHashMap<SelectionKey, String> concurrentHashMapKey;

        /**
         * @param addr address to bind to; {@code null} is passed straight to
         *             {@link InetSocketAddress}
         * @param port TCP port to listen on
         * @throws IOException declared for callers; not thrown by this constructor
         */
        public HttpEventMain(InetAddress addr, int port) throws IOException {
            // Assign directly rather than through the overridable setters.
            this.addr = addr;
            this.port = port;
            this.concurrentHashMapResponse = new ConcurrentHashMap<>();
            this.concurrentHashMapKey = new ConcurrentHashMap<>();
        }

        /**
         * Wires up the disruptor with one handler per available processor,
         * prints the server details and runs the selector loop.
         */
        public static void main(String[] args) throws Exception
        {
            final int cores = Runtime.getRuntime().availableProcessors();
            System.out.println("----- Running the server on machine with "+cores+" cores -----");

            HttpEventMain server = new HttpEventMain(null, 4333);

            HttpEventFactory factory = new HttpEventFactory();

            int bufferSize = 1024;

            Executor executor = Executors.newFixedThreadPool(cores); // a thread pool to which we can assign tasks

            Disruptor<HttpEvent> disruptor = new Disruptor<HttpEvent>(factory, bufferSize, executor);

            HttpEventHandler [] handlers = new HttpEventHandler[cores];
            for (int i = 0; i < cores; i++) {
                handlers[i] = new HttpEventHandler(i, server.getConcurrentHashMapResponse());
            }

            disruptor.handleEventsWith(handlers);
            disruptor.start();

            RingBuffer<HttpEvent> ringBuffer = disruptor.getRingBuffer();

            server.setProducer(new HttpEventProducer(ringBuffer, server.getConcurrentHashMapResponse()));

            try {
                System.out.println("\n====================Server Details====================");
                System.out.println("Server Machine: "+ InetAddress.getLocalHost().getCanonicalHostName());
                System.out.println("Port number: " + server.getPort());

            } catch (UnknownHostException e1) {
                e1.printStackTrace();
            }

            try {

                server.start();

            } catch (IOException e) {
                System.err.println("Error occured in HttpEventMain:" + e.getMessage());
                // Non-zero status: the server stopped because of an error.
                System.exit(1);
            }
        }

        /**
         * Opens the selector and the listening channel, then dispatches ready
         * keys to {@link #accept}, {@link #read} or {@link #write} forever.
         *
         * @throws IOException if the selector or server channel fails
         */
        private void start() throws IOException {
            this.selector = Selector.open();
            ServerSocketChannel serverChannel = ServerSocketChannel.open();
            serverChannel.configureBlocking(false);

            InetSocketAddress listenAddr = new InetSocketAddress(this.addr, this.port);
            serverChannel.socket().bind(listenAddr);
            serverChannel.register(this.selector, SelectionKey.OP_ACCEPT);

            System.out.println("Server ready. Ctrl-C to stop.");

            while (true) {

                this.selector.select();

                Iterator<SelectionKey> keys = this.selector.selectedKeys().iterator();
                while (keys.hasNext()) {
                    SelectionKey key = keys.next();

                    // The selected-key set is not cleared automatically.
                    keys.remove();

                    if (! key.isValid()) {
                        continue;
                    }

                    if (key.isAcceptable()) {
                        this.accept(key);
                    }
                    else if (key.isReadable()) {
                        this.read(key);
                    }
                    else if (key.isWritable()) {
                        this.write(key);
                    }
                }
            }

        }

        /**
         * Accepts a pending connection and registers it for reading.
         */
        private void accept(SelectionKey key) throws IOException {

            ServerSocketChannel serverChannel = (ServerSocketChannel) key.channel();
            SocketChannel channel = serverChannel.accept();
            if (channel == null) {
                // Non-blocking accept with no connection actually pending.
                return;
            }
            channel.configureBlocking(false);

            channel.register(this.selector, SelectionKey.OP_READ);
        }

        /**
         * Reads the request bytes from the connection, assigns the request a
         * unique ID, passes it to the producer, and switches the connection to
         * write interest with the request buffer attached.
         *
         * <p>The connection is closed on end-of-stream or when the read fails.
         */
        private void read(SelectionKey key) throws IOException {

            SocketChannel channel = (SocketChannel) key.channel();

            ByteBuffer buffer = ByteBuffer.allocate(8192);
            int numRead = -1;
            try {
                numRead = channel.read(buffer);
            }
            catch (IOException e) {
                // Treated like end-of-stream below: numRead is still -1.
                e.printStackTrace();
            }

            if (numRead == -1) {
                channel.close();
                key.cancel();
                return;
            }
            if (numRead == 0) {
                // Nothing arrived yet; wait for the next readable notification.
                return;
            }

            String requestID = RandomStringUtils.random(15, true, true);
            // Re-draw until the ID collides with no request that is in flight.
            while(concurrentHashMapKey.containsValue(requestID) || concurrentHashMapResponse.containsKey(requestID)){
                requestID = RandomStringUtils.random(15, true, true);
            }

            concurrentHashMapKey.put(key, requestID);

            this.producer.onData(requestID, buffer, numRead);

            // Same channel + selector: updates the existing key's interest set
            // and attaches the buffer the response will be written from.
            channel.register(this.selector, SelectionKey.OP_WRITE, buffer);
        }

        /**
         * Checks whether the response for the request in flight on {@code key}
         * has been stored. When it has, both bookkeeping entries are removed.
         *
         * @return {@code true} if a finished response was found (and the map
         *         entries were removed); {@code false} if the request is unknown
         *         or still marked as pending
         */
        private boolean responseReady(SelectionKey key){

            String requestId = concurrentHashMapKey.get(key);
            if (requestId == null) {
                return false;
            }

            Object response = concurrentHashMapResponse.get(requestId);
            // Compare by value, not by reference.
            if (response == null || PENDING_MARKER.equals(response)) {
                return false;
            }

            concurrentHashMapKey.remove(key);
            concurrentHashMapResponse.remove(requestId);
            return true;
        }

        /**
         * Writes the response from the buffer attached to {@code key} once it is
         * ready, and closes the connection after the whole buffer has been sent.
         *
         * <p>A non-blocking write may send only part of the buffer; in that case
         * the key stays registered and the remainder is sent on a later call.
         * A failed write closes only this connection.
         */
        private void write(SelectionKey key) throws IOException {

            SocketChannel channel = (SocketChannel) key.channel();
            ByteBuffer outBuffer = (ByteBuffer) key.attachment();

            // A key still present in the map has not started writing yet; once
            // its response is ready the entry is removed and the buffer flipped
            // exactly once. A key absent from the map is mid-write.
            if (concurrentHashMapKey.containsKey(key)) {
                if (!responseReady(key)) {
                    return;
                }
                outBuffer.flip();
            }

            boolean finished;
            try {
                channel.write(outBuffer);
                finished = !outBuffer.hasRemaining();
            } catch (IOException e) {
                e.printStackTrace();
                finished = true; // give up on this connection only
            }

            if (finished) {
                channel.close();
                key.cancel();
            }
        }

        public HttpEventProducer getProducer() {
            return producer;
        }

        public void setProducer(HttpEventProducer producer) {
            this.producer = producer;
        }

        /** @return the shared map of request ID -> pending marker or response */
        @SuppressWarnings("rawtypes") // raw return type kept for existing callers
        public ConcurrentHashMap getConcurrentHashMapResponse() {
            return concurrentHashMapResponse;
        }

        /** @param concurrentHashMapResponse the shared response map to use */
        @SuppressWarnings({"rawtypes", "unchecked"}) // raw parameter kept for existing callers
        public void setConcurrentHashMapResponse(ConcurrentHashMap concurrentHashMapResponse) {
            this.concurrentHashMapResponse = concurrentHashMapResponse;
        }

        public InetAddress getAddr() {
            return addr;
        }

        public void setAddr(InetAddress addr) {
            this.addr = addr;
        }

        public int getPort() {
            return port;
        }

        public void setPort(int port) {
            this.port = port;
        }

        public Selector getSelector() {
            return selector;
        }

        public void setSelector(Selector selector) {
            this.selector = selector;
        }
    }
    

    HttpEventProducer.java

    import com.lmax.disruptor.RingBuffer;
    
    import java.nio.ByteBuffer;
    import java.util.concurrent.ConcurrentHashMap;
    
    /**
     * Publishes requests onto the ring buffer and records each one as pending
     * in a shared map keyed by request ID.
     */
    public class HttpEventProducer
    {
        /** Value stored in the shared map until a response replaces it. */
        private static final String PENDING_MARKER = "0";

        private final RingBuffer<HttpEvent> ringBuffer;
        private final ConcurrentHashMap<String, Object> concurrentHashMap;

        /**
         * @param ringBuffer        ring buffer the events are published to
         * @param concurrentHashMap shared map in which each published request
         *                          is marked as pending
         */
        @SuppressWarnings({"rawtypes", "unchecked"}) // raw parameter kept for existing callers
        public HttpEventProducer(RingBuffer<HttpEvent> ringBuffer, ConcurrentHashMap concurrentHashMap)
        {
            this.ringBuffer = ringBuffer;
            this.concurrentHashMap = concurrentHashMap;
        }

        /**
         * Marks {@code requestId} as pending and publishes an event carrying the
         * request data.
         *
         * @param requestId identifier of the request; must not be {@code null}
         * @param buffer    buffer holding the request bytes
         * @param numRead   number of request bytes in {@code buffer}
         * @throws NullPointerException if {@code requestId} is {@code null}
         */
        public void onData(String requestId, ByteBuffer buffer, int numRead)
        {
            // Fail before a sequence is claimed: a claimed sequence that is
            // never published would stall the ring buffer.
            if (requestId == null) {
                throw new NullPointerException("requestId must not be null");
            }

            // Record the pending marker before publishing, so a response stored
            // for this event cannot be overwritten by the marker.
            concurrentHashMap.put(requestId, PENDING_MARKER);

            long sequence = ringBuffer.next();
            try
            {
                HttpEvent event = ringBuffer.get(sequence);
                event.setBuffer(buffer);
                event.setRequestId(requestId);
                event.setNumRead(numRead);
            }
            finally
            {
                // Always publish a claimed sequence.
                ringBuffer.publish(sequence);
            }
        }
    }