
Cannot read the object when using thread pools to process requests


I have to design a server with three main thread pools: one to read the data, one to process it, and one to send the results back to the client. I wrote the following code, but it keeps throwing this exception:

java.io.StreamCorruptedException: invalid stream header: 00000000
    at java.base/java.io.ObjectInputStream.readStreamHeader(ObjectInputStream.java:958)
    at java.base/java.io.ObjectInputStream.<init>(ObjectInputStream.java:392)
    at demo.NioServer.read(NioServer.java:99)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:539)
    at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
    at java.base/java.lang.Thread.run(Thread.java:840)

I don't know why. Here is my code:

package demo;


import java.io.*;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;

public class NioServer {
    private ServerSocketChannel server;
    private Selector selector;
    private ByteBuffer data;
    private ByteArrayInputStream bais;
    private ByteArrayOutputStream baos;
    private ObjectInputStream in;
    private ObjectOutputStream out;
    private final ExecutorService input = Executors.newCachedThreadPool();
    private final ExecutorService processor = Executors.newFixedThreadPool(5);
    private final ExecutorService output = Executors.newCachedThreadPool();
    private final BlockingQueue<SelectionKey> keyManager = new LinkedBlockingQueue<>(5);
    private final BlockingQueue<SelectionKey> destination = new LinkedBlockingQueue<>(5);
    private final BlockingQueue<Request> requests = new LinkedBlockingQueue<>(5);
    private final BlockingQueue<Response> responses = new LinkedBlockingQueue<>(5);
    public NioServer() throws IOException {
        server = null;
        selector = null;
    }

    private void init() throws IOException {
        server = ServerSocketChannel.open();
        selector = Selector.open();
        server.socket().bind(new InetSocketAddress("127.0.0.1", 4999));
        server.configureBlocking(false);
        server.register(selector, SelectionKey.OP_ACCEPT);
    }

    public void run() throws IOException {
        init();
        try {
            while (true) {
                selector.select();
                for(SelectionKey key : selector.selectedKeys()) {
                    if(key.isAcceptable()) {
                        accept(key);
                    } else if(key.isReadable()) {
                        keyManager.put(key);
                        input.submit(this::read);
                        processor.submit(this::process);
                        output.submit(this::write);
                    }
                }
                selector.selectedKeys().clear();
            }
        } catch(IOException | InterruptedException e) {
            e.printStackTrace();
        } finally {
            for(SelectionKey key : selector.selectedKeys()) {
                key.channel().close();
            }
        }
    }

    private void accept(SelectionKey key) throws IOException {
        ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel();
        SocketChannel socketChannel = serverSocketChannel.accept();
        socketChannel.configureBlocking(false);
        socketChannel.register(this.selector, SelectionKey.OP_READ);
    }

    private void read() {
        try {
            SelectionKey key = keyManager.take();
            SocketChannel client = (SocketChannel)key.channel();
            data = ByteBuffer.allocate(1024);
            int numRead = -1;
            try {
                numRead = client.read(data);
            } catch (IOException e) {
                key.cancel();
                client.close();
                return;
            }
            if(numRead == -1) {
                client.close();
                key.cancel();
                return;
            }
            data.flip();
            bais = new ByteArrayInputStream(data.array());
            in = new ObjectInputStream(bais);
            requests.put((Request)in.readObject());
            destination.put(key);
        } catch (IOException | InterruptedException | ClassNotFoundException e) {
            e.printStackTrace();
        }
    }

    private void process() {
        try {
            Request request = requests.take();
            responses.put(new Response(request.getInfo() + " enrolled"));
        } catch(InterruptedException e) {
            System.out.println(e.toString());
        }
    }

    private void write() {
        try {
            Response response = responses.take();
            SelectionKey key = destination.take();
            SocketChannel client = (SocketChannel)key.channel();
            baos = new ByteArrayOutputStream();
            out = new ObjectOutputStream(baos);
            out.writeObject(response);
            out.flush();
            client.write(ByteBuffer.wrap(baos.toByteArray()));
            client.close();
        } catch (InterruptedException | IOException e) {
            System.out.println(e.toString());
        }
    }

}




The exception comes from this line:

in = new ObjectInputStream(bais);

While trying to fix it, I noticed that if the fixed thread pool has size n, the server returns correct results n times and then this exception appears. Sometimes it appears even earlier.

Can someone explain why this happens and, if possible, suggest a solution? Thank you so much!


Solution

  • Oh boy, where do we start?

    I hope you're doing this for learning purposes, because writing a correct NIO server is pretty hard, so in practice you are better off using an existing one. Try Netty or Jetty; both are well-known libraries.

    Now let's dissect the problems with your code:

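    1. Do not share the ByteBuffer data field between threads (you are using it incorrectly anyway: several pool threads access it without any synchronization, which is not thread-safe). The same applies to the bais, in, baos and out fields.
    2. TCP is a stream protocol. There is no guarantee that client.read(data) reads exactly one whole message. It can read part of a message. It can read two messages. It can read the second part of the first message together with the first part of the second.
    3. You need to clear SelectionKey.OP_READ from the key's interest set before you hand the key to a reader thread, and set it back after the read has finished. Otherwise the selector keeps reporting the key as readable until the data is actually consumed, so two threads may read from the same channel.
    4. You can't just write the response to the channel once and close it. On a non-blocking channel, client.write(...) may send only part of the buffer, so you need to wait until the channel is ready (OP_WRITE) and keep writing until the buffer is empty.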
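    The specific 00000000 in your stack trace can be reproduced without any networking. The sketch below assumes the situation from point 3: a read task runs for a key whose bytes were already consumed by another task, so client.read(data) returns 0. ByteBuffer.array() ignores flip() and returns the whole 1024-byte backing array, which is still all zeros, and ObjectInputStream rejects that as a stream header. The names ZeroHeaderDemo and tryDeserializeEmptyRead are made up for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.StreamCorruptedException;
import java.nio.ByteBuffer;

// Illustrative reproduction (hypothetical class): what read() does when
// client.read(data) returned 0 bytes.
class ZeroHeaderDemo {
    static String tryDeserializeEmptyRead() throws IOException, ClassNotFoundException {
        ByteBuffer data = ByteBuffer.allocate(1024); // fresh buffer: every byte is 0
        // numRead == 0 here, so nothing was put into the buffer
        data.flip(); // limit becomes 0, but array() still exposes all 1024 bytes
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data.array()))) {
            return "unexpectedly succeeded: " + in.readObject();
        } catch (StreamCorruptedException e) {
            // thrown by the constructor while checking the 4-byte stream header
            return e.getMessage();
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(tryDeserializeEmptyRead());
    }
}
```

    Even with the correct length, deserializing whatever a single read() returned stays fragile because of point 2.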

    This is what I spotted at first glance; there are probably more.
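    For point 2, one common fix is to give every message an explicit frame and to keep partially received bytes per connection instead of in a shared field (which also takes care of point 1). The sketch below assumes a hypothetical protocol in which the client sends a 4-byte big-endian length followed by that many bytes of Java serialization data; FrameAccumulator, feed and encode are illustrative names, not part of any library.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Hypothetical per-connection decoder: create one per SocketChannel (for example as
// the key's attachment) so no buffer or stream is shared between clients or threads.
// Assumed wire format: 4-byte big-endian length, then that many serialized bytes.
class FrameAccumulator {
    private ByteBuffer pending = ByteBuffer.allocate(1024); // kept in write mode between calls

    // Feeds the bytes of one read (chunk must be in read mode, i.e. flipped) and
    // returns every message that is now complete; a partial frame is kept for later.
    List<Object> feed(ByteBuffer chunk) throws IOException, ClassNotFoundException {
        ensureCapacity(chunk.remaining());
        pending.put(chunk);
        pending.flip(); // read mode: look for complete frames
        List<Object> messages = new ArrayList<>();
        while (pending.remaining() >= Integer.BYTES) {
            int length = pending.getInt(pending.position()); // peek at the prefix
            if (length < 0) {
                throw new IOException("negative frame length: " + length);
            }
            if (pending.remaining() - Integer.BYTES < length) {
                break; // body not fully received yet: wait for the next read
            }
            pending.getInt(); // consume the prefix
            byte[] body = new byte[length];
            pending.get(body);
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(body))) {
                messages.add(in.readObject());
            }
        }
        pending.compact(); // move leftover bytes to the front, back to write mode
        return messages;
    }

    // Grows the buffer so that 'extra' more bytes fit (pending is in write mode here).
    private void ensureCapacity(int extra) {
        if (pending.remaining() < extra) {
            ByteBuffer bigger = ByteBuffer.allocate(
                    Math.max(pending.capacity() * 2, pending.position() + extra));
            pending.flip();
            bigger.put(pending);
            pending = bigger;
        }
    }

    // Builds one frame in the assumed format (what a client would send).
    static ByteBuffer encode(Serializable message) throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(baos)) {
            out.writeObject(message);
        }
        byte[] body = baos.toByteArray();
        ByteBuffer frame = ByteBuffer.allocate(Integer.BYTES + body.length);
        frame.putInt(body.length).put(body);
        frame.flip(); // ready to be written to a channel
        return frame;
    }
}
```

    In the server, one accumulator per client could be attached at accept time with socketChannel.register(selector, SelectionKey.OP_READ, new FrameAccumulator()), and read() would pass each flipped buffer to ((FrameAccumulator) key.attachment()).feed(...) and handle every object it returns. The client has to send frames built the same way, for example with encode(request).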

    For a full list of problems, you should post your code on Code Review.

    If you are interested in a working example, you could look at the Netty source code, but it is very complicated.

    I have a pet project with a simple NIO connector, but it is far from perfect. For example, it has the third problem from the list above.
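    Points 3 and 4 both come down to managing the key's interest set. The helper below is a minimal sketch under these assumptions: the selector thread clears OP_READ before handing a key to a worker, the worker re-arms it when done, and a write that does not fit into the socket's send buffer is finished later when the selector reports OP_WRITE. InterestOps, suspendReads, resumeReads and writeSome are hypothetical names. Because the interest set is changed from a worker thread, wakeup() is what makes a selector already blocked in select() notice the change; according to the SelectionKey documentation, a change made during a selection operation is only seen by the next one.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.WritableByteChannel;

// Hypothetical helpers for points 3 and 4: manage a key's interest set so that only
// one thread reads a channel at a time and writes are finished when the channel is ready.
final class InterestOps {
    private InterestOps() {}

    // Selector thread, before handing the key to a worker: stop selecting it for reads.
    static void suspendReads(SelectionKey key) {
        key.interestOps(key.interestOps() & ~SelectionKey.OP_READ);
    }

    // Worker thread, after the read is done: select the key for reads again and wake
    // the selector, because a select() already in progress does not see the change.
    static void resumeReads(SelectionKey key) {
        if (key.isValid()) { // the channel may have been closed meanwhile
            key.interestOps(key.interestOps() | SelectionKey.OP_READ);
            key.selector().wakeup();
        }
    }

    // Writes as much of buf as the non-blocking channel accepts. Returns true when
    // everything was sent; otherwise adds OP_WRITE so the selector thread can call
    // this again (with the same buffer) once the channel becomes writable.
    static boolean writeSome(SelectionKey key, ByteBuffer buf) throws IOException {
        WritableByteChannel channel = (WritableByteChannel) key.channel();
        while (buf.hasRemaining()) {
            if (channel.write(buf) == 0) { // send buffer full: nothing accepted right now
                key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
                key.selector().wakeup();
                return false;
            }
        }
        key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE); // nothing pending
        return true;
    }
}
```

    Applied to the question's code, that would mean calling InterestOps.suspendReads(key) just before keyManager.put(key) in run(), and resumeReads(key) at the end of read() if the channel is still open. In write(), the channel should only be closed once writeSome returns true; otherwise the leftover buffer can be attached to the key, and an extra key.isWritable() branch in run() can call writeSome again.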