Search code examples
javamultithreadingsocketsserverjunit4

How do I test a multi-threading server


I have a main thread server which basically listen to whoever want to connect to the port

/**
 * The main server thread receives request from and sends
 * response to clients.
 */
public class Server {

  /*
  Port number for the client connection
   */
  private static final int PORT = 3000;

  /*
  The number of client can be connected
   */
  private static final int SIZE = 10;

  /*
  The executor
   */
  private static ExecutorService executorService = Executors.newFixedThreadPool(SIZE);



  /**
   * Starts the main server.
   */
  public static void main(String[] args) {
    /*
    All the information are stored into the queue
    */
    BlockingQueue<Message> messageQueue = new LinkedBlockingQueue<>();
    /*
    All the clients are stored into the map
    */
    ConcurrentMap<byte[], Boolean> clientManagement = new ConcurrentHashMap<>();
    runMainServer(messageQueue, clientManagement);

  }

  private static void runMainServer(BlockingQueue<Message> messageQueue, ConcurrentMap<byte[], Boolean> clientManagement) {
    try (
        ServerSocket serverSocket = new ServerSocket(PORT);
    ) {
      System.out.println("Starting server");
      while (true) {
        System.out.println("Waiting for request");
        Socket socket = serverSocket.accept();
        System.out.println("Processing request");
        ClientThread newClient = new ClientThread(socket, messageQueue, clientManagement);
        executorService.submit(newClient);
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

}

And I have many multi-threading sub-server to handle each identical client. The client is first going to be accepted by the server, and checked if first message that the server received is a connect_message. If it is, then they are officially connected. There are many more message other than connect_message. But I am just not going to be too specific on them.

/**
 * The client thread.
 */
public class ClientThread implements Runnable {

  private Socket socket;
  private BlockingQueue<Message> messageQueue;
  private ConcurrentMap<byte[], Boolean> clientManagement;
  private byte[] username;

  /**
   *
   *
   * @param socket
   * @param messageQueue
   * @param clientManagement
   */
  public ClientThread(Socket socket, BlockingQueue<Message> messageQueue, ConcurrentMap<byte[], Boolean> clientManagement) {
    this.socket = socket;
    this.messageQueue = messageQueue;
    this.clientManagement = clientManagement;
    this.username = new byte[1];
  }

  /**
   *
   */
  @Override
  public void run() {
    try (
        ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
        ObjectOutputStream out = new ObjectOutputStream(socket.getOutputStream());
        ) {
      Message m = (Message) in.readObject();
      if (m.getIdentifier() == MessageIdentifier.CONNECT_MESSAGE) {
        ConnectMessage cm = (ConnectMessage) m;
        this.username = cm.getUsername();
        clientManagement.put(cm.getUsername(), true);
        byte[] cntMsg = "Successfully Connected!".getBytes();
        ConnectResponse cr = new ConnectResponse(true, cntMsg.length, cntMsg);
        out.writeObject(cr);
      } else {
        // Handle failing request
        handleFailedMsg(out, "Client should connect first");
        socket.close();
        throw new IllegalArgumentException("Connect unsuccessfully");
      }
      handleClient(in, out);
      socket.close();
    } catch (IOException | ClassNotFoundException | InterruptedException e) {
      e.printStackTrace();
    }
  }

  /**
   *
   * @param in
   * @param out
   * @throws InterruptedException
   * @throws IOException
   * @throws ClassNotFoundException
   */
  private void handleClient(ObjectInputStream in, ObjectOutputStream out)
      throws InterruptedException, IOException, ClassNotFoundException {
    while (true) {
      // Handle message taken from the queue
      Message msgFromQueue = messageQueue.take();
      handleQueueRequest(msgFromQueue, out);

      // Handle request obtained by user
      Message request = (Message) in.readObject();
      // Handle disconnect
      if (request.getIdentifier() == MessageIdentifier.DISCONNECT_MESSAGE) {
        DisconnectMessage dm = (DisconnectMessage) request;
        // If the message is not for this thread, then put it back and ignore it.
        if (!Arrays.equals(username, dm.getUsername())) {
          messageQueue.add(request);
          continue;
        }
        // Check if the username is inside the client map
        if (!clientManagement.containsKey(dm.getUsername())) {
          handleFailedMsg(out, "The client doesn't exist");
        }
        // Disconnect
        clientManagement.remove(dm.getUsername());
        // Create disconnect response
        byte[] message = "See you again".getBytes();
        DisconnectResponse dr = new DisconnectResponse(true, message.length, message);
        // Write to the client
        out.writeObject(dr);
        break;
      }
      // Handle other
      if (!handleRequest(request, out)) {
        handleFailedMsg(out, "The request failed due to incorrect username.");
      }
    }
  }

  /**
   *
   * @param request
   * @param out
   * @return
   * @throws IOException
   */
  private boolean handleRequest(Message request, ObjectOutputStream out) throws IOException {
    switch (request.getIdentifier()) {
      // If broadcast, then every one should know
      case BROADCAST_MESSAGE:
        BroadcastMessage bm = (BroadcastMessage) request;
        if (!Arrays.equals(username, bm.getUsername())) {
          return false;
        }
        messageQueue.add(request);
        break;
      // If user want the list of connected users
      case QUERY_CONNECTED_USERS:
        QueryUsersMessage qu = (QueryUsersMessage) request;
        if (!Arrays.equals(username, qu.getUsername())) {
          return false;
        }
        List<Pair<Integer, byte[]>> userList = new ArrayList<>();
        for (byte[] username : clientManagement.keySet()) {
          Pair<Integer, byte[]> user = new Pair<>(username.length, username);
          userList.add(user);
        }
        // Create a new query response containing all the users
        QueryResponse qr = new QueryResponse(clientManagement.keySet().size(), userList);
        out.writeObject(qr);
        break;
      // If user wants to send a direct message to the other user
      case DIRECT_MESSAGE:
        DirectMessage dm = (DirectMessage) request;
        if (!Arrays.equals(username, dm.getUsername())) {
          return false;
        }
        messageQueue.add(request);
        break;
      // If user wants to send an insult to the other user and broadcast to the chat room
      case SEND_INSULT:
        SendInsultMessage si = (SendInsultMessage) request;
        if (!Arrays.equals(username, si.getUsername())) {
          return false;
        }
        messageQueue.add(request);
        break;
    }
    return true;
  }

  /**
   *
   * @param out
   * @param description
   * @throws IOException
   */
  public void handleFailedMsg(ObjectOutputStream out, String description) throws IOException {
    byte[] failedMsg = description.getBytes();
    FailedMessage fm = new FailedMessage(failedMsg.length, failedMsg);
    out.writeObject(fm);
  }

  /**
   *
   * @param request
   * @param out
   * @throws IOException
   */
  public void handleQueueRequest(Message request, ObjectOutputStream out) throws IOException {
    switch (request.getIdentifier()) {
      case SEND_INSULT:
        // Gets the message from queue
        SendInsultMessage si = (SendInsultMessage) request;
        // Check if the user already gotten the message
        if (!si.getOtherUsers().contains(username)) {
          out.writeObject(si);
          si.addUsers(username);
        }
        // Check if all the users already gotten the message
        if (si.getOtherUsers().size() < clientManagement.keySet().size()) {
          messageQueue.add(si);
        }
        break;
      case DIRECT_MESSAGE:
        DirectMessage dm = (DirectMessage) request;
        // Check if the message is for this user
        if (Arrays.equals(username, dm.getRecipientUsername())) {
          out.writeObject(dm);
        } else { // If not for this user then put it back
          messageQueue.add(dm);
        }
        break;
      case BROADCAST_MESSAGE:
        // Gets the message from queue
        BroadcastMessage bm = (BroadcastMessage) request;
        // Check if the user already gotten the message
        if (!bm.getOtherUsers().contains(username)) {
          out.writeObject(bm);
          bm.addUsers(username);
        }
        // Check if all the users already gotten the message
        if (bm.getOtherUsers().size() < clientManagement.keySet().size()) {
          messageQueue.add(bm);
        }
        break;
    }
  }

I want to do JUnit test for my server. What is the best way to test a multi-threading server like this?

Here are the JUnit test code that I am trying. I first start a thread which is accepted by the server. Then I am going to start a client and pretend that the client is sending something to the server. I first want to try a connect_message to see how connection work. But so far, the test doesn't seem to responding on JUnit test. It just keeps running, nothing happen

public class ClientThreadTest {

  private Thread foo;
  private List<ClientThread> clientList;
  private BlockingQueue<Message> messageQueue;
  private ConcurrentMap<byte[], Boolean> clientManagement;
  private static final int PORT = 3000;

  @Before
  public void setUp() throws Exception {
    messageQueue = new LinkedBlockingQueue<>();
    clientManagement = new ConcurrentHashMap<>();
  }

  @Test
  public void run() throws IOException, ClassNotFoundException {
    ServerSocket socket = new ServerSocket(PORT);
    foo = new Thread(new ClientThread(socket.accept(), messageQueue, clientManagement));
    foo.start();
    Socket fooClient = new Socket("localhost", PORT);
    ObjectOutputStream out = new ObjectOutputStream(fooClient.getOutputStream());
    ObjectInputStream in = new ObjectInputStream(fooClient.getInputStream());
    // First test Connection message
    byte[] username = "foo".getBytes();
    ConnectMessage cm = new ConnectMessage(username.length, username);
    // The message need to get
    byte[] cntMsg = "Successfully Connected!".getBytes();
    ConnectResponse cr = new ConnectResponse(true, cntMsg.length, cntMsg);
    out.writeObject(cm);
    ConnectResponse m = (ConnectResponse) in.readObject();
    System.out.println(Arrays.toString(m.getMessage()));
  }

Solution

  • I have solved my own problem!

    For anyone who is doing JUnit testing on a multi-threading server. Here is my suggestion:

    You have to start you main server at the beginning, before anything else. Keep you main server listening to some port that you give it.

    Then you have to start your client and you have to give it the same port number which you gave to the main server

    Last but not least, you can start your thread to deal with a specific client. Somehow if I instantiated my thread as ClientThread foo, and I called foo.run(), it won't work. I have to instantiate Thread foo, and make my ClientThread() as an input to Thread(), and call foo.start() instead!

    Now it is working!