
Multithreaded socket server: shut down by message (Java)


I have a multithreaded process that listens for messages and processes them. I want to be able to shut the process down when one of the received messages is "shutdown". I have implemented everything except the shutdown part.

I have a "Multi" class that extends java.net.ServerSocket and has a start method. Inside it:

java.net.Socket socket = null;
while (true) {
    try {
        socket = this.accept();
        new Thread(new SocketThread(socket, verifier, threading)).start();
    } catch (IOException e) {
        e.printStackTrace();
    }
}

SocketThread is another class that implements Runnable. Is there a way to make a "shutdown" message received by one SocketThread stop this accept loop?


Solution

  • I haven't had to do this before (I rarely write a raw ServerSocket server), but for this kind of thread coordination I would try something like the following:

    import java.io.BufferedReader;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.net.ServerSocket;
    import java.net.Socket;
    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.concurrent.atomic.AtomicBoolean;
    
    public class App {
        public static void main(String[] args) {
            new App().run();
        }
    
        public void run() {
            try {
                System.out.println("Starting...");
                // Shared flag: a handler sets it to false when it receives "shutdown".
                AtomicBoolean running = new AtomicBoolean(true);
                Collection<Socket> sockets = new ArrayList<>();
                Collection<Thread> threads = new ArrayList<>();
                try (ServerSocket socketServer = new ServerSocket(10101)) {
                    System.out.println("Started.");
                    while (running.get()) {
                        // Blocks until a client connects, including the handler's wake-up connection.
                        Socket socket = socketServer.accept();
                        sockets.add(socket);
                        // Re-check the flag: the connection that just arrived may be the wake-up one.
                        if (running.get()) {
                            Thread thread = new Thread(new SocketHandler(socket, running));
                            thread.start();
                            threads.add(thread);
                        }
                    }
                    System.out.println("Stopping...");
                    // Closing each client socket makes a handler blocked in readLine() throw,
                    // so every handler thread can finish.
                    sockets.forEach(socket -> {
                        try {
                            socket.close();
                        } catch (IOException e) {
                            e.printStackTrace();
                        }
                    });
                    // Wait for all handler threads to exit.
                    threads.forEach(thread -> {
                        try {
                            thread.join();
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }
                    });
                }
                System.out.println("Stopped.");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    
        static class SocketHandler implements Runnable {
            private final Socket socket;
            private final AtomicBoolean running;
    
            SocketHandler(Socket socket, AtomicBoolean running) {
                this.socket = socket;
                this.running = running;
            }
    
            @Override
            public void run() {
                try {
                    System.out.println("Client connected.");
                    try (BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
                        String command;
                        // readLine() returns null once the client closes its end of the connection.
                        while ((command = in.readLine()) != null) {
                            System.out.println("Command received: " + command);
                            if (command.equals("shutdown")) {
                                running.set(false);
                                // The main thread is still blocked in accept(); connect to the server
                                // so accept() returns and the loop sees running == false.
                                try (Socket tmpSocket = new Socket("localhost", 10101)) {
                                    // Connecting is enough; nothing needs to be sent.
                                }
                                break;
                            }
                            // process other commands
                        }
                    }
                    System.out.println("Client disconnected.");
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }
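An alternative to the wake-up connection, sketched here against the question's own structure: `ServerSocket.close()` makes any thread blocked in `accept()` throw a `SocketException`, so a handler that holds a reference to the server can stop it directly. Note that the question's loop catches `IOException` inside `while (true)`, so once the server socket is closed it would spin forever printing stack traces; the loop has to check `isClosed()` and exit. The `Multi` and `SocketThread` shapes below are simplified, hypothetical versions (the question's `verifier` and `threading` arguments are omitted), and port 0 asks the OS for any free port.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;

// Hypothetical, simplified version of the question's Multi class (verifier/threading omitted).
public class Multi extends ServerSocket {
    public Multi(int port) throws IOException {
        super(port);
    }

    public void start() {
        while (!isClosed()) {
            try {
                Socket socket = accept();
                new Thread(new SocketThread(socket, this)).start();
            } catch (IOException e) {
                // accept() throws SocketException once another thread closes this server;
                // only report exceptions that happen while the server is still open.
                if (!isClosed()) {
                    e.printStackTrace();
                }
            }
        }
        System.out.println("Server stopped.");
    }

    // Hypothetical handler: closes the whole server when it reads "shutdown".
    static class SocketThread implements Runnable {
        private final Socket socket;
        private final ServerSocket server;

        SocketThread(Socket socket, ServerSocket server) {
            this.socket = socket;
            this.server = server;
        }

        @Override
        public void run() {
            try (Socket s = socket;
                 BufferedReader in = new BufferedReader(new InputStreamReader(s.getInputStream()))) {
                String line;
                while ((line = in.readLine()) != null) {
                    if (line.equals("shutdown")) {
                        server.close(); // unblocks accept() in start()
                        return;
                    }
                    System.out.println("Received: " + line);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) throws Exception {
        Multi server = new Multi(0); // port 0: let the OS pick a free port
        Thread serverThread = new Thread(server::start);
        serverThread.start();
        try (Socket client = new Socket(InetAddress.getLoopbackAddress(), server.getLocalPort());
             PrintWriter out = new PrintWriter(client.getOutputStream(), true)) {
            out.println("hello");
            out.println("shutdown");
        }
        serverThread.join();
    }
}
```

Unlike the answer's version, this sketch does not close connections belonging to other clients, so their handler threads keep the JVM alive until those clients disconnect; tracking and closing them as the answer does would still be needed.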
    

    Update: Changed the example so that the handler opens a connection to the server itself, which unblocks the pending accept() call.
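The wake-up connection only works if it targets the exact port the server listens on. Below is a small standalone sketch of the same pattern that reads the port from `getLocalPort()` instead of repeating `10101`; the class and method names (`WakeUpDemo`, `acceptUntilStopped`, `requestStop`) are illustrative, and port 0 lets the OS choose a free port.

```java
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.atomic.AtomicBoolean;

public class WakeUpDemo {
    // Accepts connections until the flag is cleared; returns how many real clients were seen.
    public static int acceptUntilStopped(ServerSocket server, AtomicBoolean running) throws IOException {
        int accepted = 0;
        while (running.get()) {
            // A real server would hand the socket to a handler thread instead of closing it here.
            try (Socket socket = server.accept()) {
                // The wake-up connection also lands here; re-check the flag before counting it.
                if (running.get()) {
                    accepted++;
                }
            }
        }
        return accepted;
    }

    // Called from any thread that wants the accept loop to stop.
    public static void requestStop(ServerSocket server, AtomicBoolean running) throws IOException {
        running.set(false);
        // Connect to the server's own port so a blocked accept() returns.
        try (Socket wakeUp = new Socket(InetAddress.getLoopbackAddress(), server.getLocalPort())) {
            // Connecting is enough; nothing needs to be sent.
        }
    }

    public static void main(String[] args) throws Exception {
        AtomicBoolean running = new AtomicBoolean(true);
        try (ServerSocket server = new ServerSocket(0)) {
            Thread acceptor = new Thread(() -> {
                try {
                    System.out.println("Handled " + acceptUntilStopped(server, running) + " client(s)");
                } catch (IOException e) {
                    e.printStackTrace();
                }
            });
            acceptor.start();
            // One ordinary client connection, then the stop request.
            new Socket(InetAddress.getLoopbackAddress(), server.getLocalPort()).close();
            requestStop(server, running);
            acceptor.join();
        }
    }
}
```

The printed count can be 0 or 1 depending on thread timing, but the loop always terminates because the flag is cleared before the wake-up connection is made.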

    Update: Added handling for the case where the client disconnects (readLine() returns null). Credit to @user207421 for pointing this out (thanks).
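`BufferedReader.readLine()` returns `null` once the peer has closed its side of the connection, which is what the handler's loop relies on to detect a disconnect. A minimal loopback sketch of that behaviour (the class and method names are illustrative):

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
import java.util.List;

public class DisconnectDemo {
    // Reads lines until the client disconnects, i.e. until readLine() returns null.
    public static List<String> readUntilDisconnect(Socket socket) throws IOException {
        List<String> lines = new ArrayList<>();
        try (BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
            String line;
            while ((line = in.readLine()) != null) {
                lines.add(line);
            }
        }
        return lines;
    }

    public static void main(String[] args) throws IOException {
        try (ServerSocket server = new ServerSocket(0)) {
            // The connection waits in the backlog, so the client can write and close before accept().
            try (Socket client = new Socket(InetAddress.getLoopbackAddress(), server.getLocalPort());
                 PrintWriter out = new PrintWriter(client.getOutputStream(), true)) {
                out.println("first");
                out.println("second");
            } // closing the client ends the stream on the server side
            try (Socket accepted = server.accept()) {
                System.out.println(readUntilDisconnect(accepted)); // prints [first, second]
            }
        }
    }
}
```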

    Update: Changed the sample code to handle multiple client sockets / threads. Note that closing a socket while its handler is blocked in readLine() throws a SocketException, which is currently just printed to stderr. You'll probably want to handle that differently.
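One way to handle that exception, assuming the handler can see the same `running` flag as in the answer: treat a `SocketException` raised after shutdown has been requested as the normal way the handler ends, and report it only otherwise. This relies on the server clearing the flag before it closes the sockets. A minimal sketch (the class name `QuietHandler` and the `closedByShutdown()` accessor are illustrative):

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.concurrent.atomic.AtomicBoolean;

public class QuietHandler implements Runnable {
    private final Socket socket;
    private final AtomicBoolean running;
    private volatile boolean closedByShutdown;

    public QuietHandler(Socket socket, AtomicBoolean running) {
        this.socket = socket;
        this.running = running;
    }

    public boolean closedByShutdown() {
        return closedByShutdown;
    }

    @Override
    public void run() {
        try (BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
            String command;
            while ((command = in.readLine()) != null) {
                System.out.println("Command received: " + command);
            }
        } catch (SocketException e) {
            if (running.get()) {
                // The connection failed while the server was still running: a real error.
                e.printStackTrace();
            } else {
                // The server closed this socket on purpose during shutdown.
                closedByShutdown = true;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) throws Exception {
        AtomicBoolean running = new AtomicBoolean(true);
        try (ServerSocket server = new ServerSocket(0);
             Socket client = new Socket(InetAddress.getLoopbackAddress(), server.getLocalPort());
             Socket accepted = server.accept()) {
            QuietHandler handler = new QuietHandler(accepted, running);
            Thread thread = new Thread(handler);
            thread.start();
            running.set(false); // request shutdown first...
            accepted.close();   // ...then close the socket the handler is reading from
            thread.join();
            System.out.println("Closed by shutdown: " + handler.closedByShutdown());
        }
    }
}
```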

    Update: You might also find the following code, which simulates several concurrent clients, helpful:

    import java.io.BufferedWriter;
    import java.io.IOException;
    import java.io.OutputStreamWriter;
    import java.net.Socket;
    
    public class Clients {
        public static void main(String[] args) {
            // Four clients send ordinary messages; the fifth asks the server to shut down.
            Thread thread1 = new Thread(() -> new NormalClient().run());
            Thread thread2 = new Thread(() -> new NormalClient().run());
            Thread thread3 = new Thread(() -> new NormalClient().run());
            Thread thread4 = new Thread(() -> new NormalClient().run());
            Thread thread5 = new Thread(() -> new ShutdownClient().run());
            thread1.start();
            thread2.start();
            thread3.start();
            thread4.start();
            thread5.start();
        }
    }
    
    class NormalClient {
        void run() {
            try {
                try (Socket socket = new Socket("localhost", 10101);
                     BufferedWriter out = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream()))) {
    
                    // Send ten lines, one per second.
                    for (int i = 0; i < 10; i++) {
                        out.write("hello " + i);
                        out.newLine();
                        out.flush();
                        Thread.sleep(1000);
                    }
                }
            } catch (IOException | InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
    
    class ShutdownClient {
        void run() {
            try {
                try (Socket socket = new Socket("localhost", 10101);
                     BufferedWriter out = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream()))) {
    
                    // Wait 8 seconds, while the normal clients are still sending, then request shutdown.
                    Thread.sleep(8000);
                    out.write("shutdown");
                    out.newLine();
                    out.flush();
                }
            } catch (IOException | InterruptedException e) {
                e.printStackTrace();
            }
        }
    }