Send large data set using Undertow WebSockets efficiently


I have a large ConcurrentHashMap (cache.getCache()) that holds all my data (about 500+ MB, and it can grow over time). Clients access it through an API implemented with the plain Java HttpServer. Here is the simplified code:

JsonWriter jsonWriter = new JsonWriter(new OutputStreamWriter(new BufferedOutputStream(new GZIPOutputStream(exchange.getResponseBody()))));
new GsonBuilder().create().toJson(cache.getCache(), CacheContainer.class, jsonWriter);
jsonWriter.close(); // flushes the writer chain and finishes the GZIP stream

Clients also send some filters, so they don't actually get all the data every time, but the map is updated constantly, so clients have to refresh often to have the latest data. This is inefficient, so I decided to push data updates to clients in real time using WebSockets.

I chose Undertow for this because I can simply import it from Maven and there is no extra configuration I have to do on the server.

On WS connect I add the channel to a HashSet and send the whole dataset (the client sends a message with some filters before getting initial data, but I removed this part from the example):

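import com.google.gson.Gson;
import io.undertow.websockets.WebSocketConnectionCallback;
import io.undertow.websockets.core.AbstractReceiveListener;
import io.undertow.websockets.core.BufferedTextMessage;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSockets;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

public class MyConnectionCallback implements WebSocketConnectionCallback {
  CacheContainer cache;
  Gson gson = new Gson();
  // Note: HashSet is not thread-safe, yet it is touched by both the receive listener and the pusher thread.
  Set<WebSocketChannel> clients = new HashSet<>();
  BlockingQueue<String> queue = new LinkedBlockingQueue<>();

  public MyConnectionCallback(CacheContainer cache) {
    this.cache = cache;
    Thread pusherThread = new Thread(() -> {
      try {
        while (true) {
          push(queue.take()); // blocks until an update is available
        }
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt(); // stop pushing when interrupted
      }
    });
    pusherThread.start();
  }

  @Override
  public void onConnect(WebSocketHttpExchange webSocketHttpExchange, WebSocketChannel webSocketChannel) {
    webSocketChannel.getReceiveSetter().set(new AbstractReceiveListener() {
      @Override
      protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) {
        clients.add(webSocketChannel);
        // The whole dataset is serialized into one String here (problem #1 below).
        WebSockets.sendText(gson.toJson(cache.getCache()), webSocketChannel, null);
      }
    });
    webSocketChannel.resumeReceives(); // start delivering incoming messages to the listener
  }

  private void push(String message) {
    Set<WebSocketChannel> closed = new HashSet<>();
    clients.forEach((webSocketChannel) -> {
      if (webSocketChannel.isOpen()) {
        WebSockets.sendText(message, webSocketChannel, null);
      } else {
        closed.add(webSocketChannel); // remember dead channels, remove them after iterating
      }
    });
    closed.forEach(clients::remove);
  }

  public void putMessage(String message) throws InterruptedException {
    queue.put(message);
  }
}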

After every change to my cache I get the new value and put it into the queue (I do not directly serialize the myUpdate object because there is other logic behind that in the updateCache method). There is only one thread responsible for updating the cache:

cache.updateCache(key, myUpdate);
Map<Key,Value> tempMap = new HashMap<>();
tempMap.put(key, cache.getValue(key));
webSocketServer.putMessage(gson.toJson(tempMap));

The problems I see with this approach:

  1. On initial connect the whole dataset is converted to a String, and I fear that too many simultaneous requests could cause the server to run out of memory. WebSockets.sendText only accepts String and ByteBuffer.
  2. If I add the channel to the clients set first and then send the data, a push could reach the client before the initial data is sent, leaving the client in an invalid state.
  3. If I send the initial data first and then add the channel to the clients set, push messages that arrive while the initial data is being sent will be lost, again leaving the client in an invalid state.

The solution I came up with for problems #2 and #3 is to put the messages in a queue: I would convert the Set<WebSocketChannel> into a Map<WebSocketChannel,Queue<String>> and send the queued messages only after the client has received the initial dataset. I welcome any other suggestions here.
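
The per-client queue idea can be sketched without Undertow, using only the standard library. This is a minimal sketch under stated assumptions, not the asker's implementation: BufferingClient and its sink are hypothetical names, and the sink stands in for whatever call actually writes to the WebSocket. Both methods are synchronized on the client so that "check the flag, then enqueue or send" cannot interleave with the snapshot being sent.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.function.Consumer;

/** Hypothetical per-client state: holds updates back until the snapshot has been sent. */
class BufferingClient {
    private final Consumer<String> sink;      // placeholder for the real WebSocket send
    private final Queue<String> backlog = new ArrayDeque<>();
    private boolean snapshotSent = false;

    BufferingClient(Consumer<String> sink) {
        this.sink = sink;
    }

    /** Called by the pusher thread for every update. */
    synchronized void onUpdate(String update) {
        if (snapshotSent) {
            sink.accept(update);              // normal path: forward immediately
        } else {
            backlog.add(update);              // snapshot not sent yet: hold the update back
        }
    }

    /** Called by the connect handler: snapshot first, then held-back updates in arrival order. */
    synchronized void sendSnapshot(String snapshot) {
        sink.accept(snapshot);
        String queued;
        while ((queued = backlog.poll()) != null) {
            sink.accept(queued);
        }
        snapshotSent = true;
    }
}

class BufferingClientDemo {
    /** Simulates one update arriving before the snapshot and one after it. */
    static List<String> run() {
        List<String> wire = new ArrayList<>();
        BufferingClient client = new BufferingClient(wire::add);
        client.onUpdate("update-1");          // arrives before the snapshot is sent
        client.sendSnapshot("snapshot");
        client.onUpdate("update-2");          // arrives afterwards
        return wire;
    }

    public static void main(String[] args) {
        System.out.println(run());            // prints [snapshot, update-1, update-2]
    }
}
```

One trade-off of this sketch is that the pusher thread blocks on a client for as long as that client's snapshot is being sent. An update that was already reflected in the snapshot may also be replayed from the backlog; that should be harmless here because each update carries the current value for its key.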

As for problem #1 my question is what would be the most efficient way to send the initial data over WebSocket? For example something like writing with the JsonWriter directly to the WebSocket.
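
One way to avoid materializing the whole dataset at once is to let the JSON writer write into an OutputStream that cuts its input into bounded chunks and hands each chunk to a sink. The following is only a sketch under assumptions: ChunkingOutputStream is a hypothetical helper, the sink is a placeholder, and whether each chunk can be forwarded as one WebSocket fragment through Undertow's lower-level channel API is something to verify against the Undertow documentation.

```java
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical helper: cuts everything written to it into chunks of at most chunkSize bytes. */
class ChunkingOutputStream extends OutputStream {
    private final byte[] buffer;
    private final Consumer<ByteBuffer> sink;  // placeholder, e.g. "send this chunk as one fragment"
    private int used = 0;

    ChunkingOutputStream(int chunkSize, Consumer<ByteBuffer> sink) {
        this.buffer = new byte[chunkSize];
        this.sink = sink;
    }

    @Override
    public void write(int b) {
        buffer[used++] = (byte) b;
        if (used == buffer.length) {
            emitChunk();                      // buffer full: hand it over and start a new chunk
        }
    }

    @Override
    public void close() {
        emitChunk();                          // emit the final, possibly shorter chunk
    }

    private void emitChunk() {
        if (used > 0) {
            // Copy so the sink may keep the chunk while the buffer is reused.
            sink.accept(ByteBuffer.wrap(Arrays.copyOf(buffer, used)));
            used = 0;
        }
    }
}

class ChunkingDemo {
    /** Writes the text as UTF-8 through the chunking stream and records each chunk's size. */
    static List<Integer> chunkSizes(String json, int chunkSize) {
        List<Integer> sizes = new ArrayList<>();
        try (Writer writer = new OutputStreamWriter(
                new ChunkingOutputStream(chunkSize, chunk -> sizes.add(chunk.remaining())),
                StandardCharsets.UTF_8)) {
            writer.write(json);               // a JsonWriter could wrap this writer instead
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(chunkSizes("{\"key\":\"value\"}", 4)); // prints [4, 4, 4, 3]
    }
}
```

With a chunk size in the range of a few kilobytes, peak memory per connection would be bounded by the chunk size plus the writer's internal buffer rather than by the dataset size, provided the sink does not accumulate the chunks itself.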

I realize the clients could make the initial call using the API and subscribe to the WebSocket for changes, but this approach makes the clients responsible for having the correct state (they need to subscribe to WS, queue WS messages, get initial data using API, and then apply queued WS messages to their dataset after getting initial data) and I don't want to leave the control over that up to them because the data is sensitive.


Solution

  • To resolve problems #2 and #3 I set a push lock flag on each client that is released only once the initial data has been sent. While the push lock is set, messages that arrive are placed in that client's queue. Queued messages are then sent before any new messages.

    I mitigated problem #1 by serializing directly to a ByteBuffer instead of a String. This saves some memory because the JSON is held as UTF-8 bytes rather than as a Java String (String uses UTF-16 by default).

    Final code:

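    // Imports for both classes below (each public class belongs in its own file).
    import com.google.gson.Gson;
    import com.google.gson.GsonBuilder;
    import com.google.gson.stream.JsonWriter;
    import io.undertow.websockets.WebSocketConnectionCallback;
    import io.undertow.websockets.core.AbstractReceiveListener;
    import io.undertow.websockets.core.BufferedTextMessage;
    import io.undertow.websockets.core.WebSocketChannel;
    import io.undertow.websockets.core.WebSockets;
    import io.undertow.websockets.spi.WebSocketHttpExchange;
    import java.io.ByteArrayOutputStream;
    import java.io.IOException;
    import java.io.OutputStreamWriter;
    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Queue;
    import java.util.Set;
    import java.util.concurrent.BlockingQueue;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentLinkedQueue;
    import java.util.concurrent.LinkedBlockingQueue;
    import java.util.concurrent.atomic.AtomicBoolean;

    public class WebSocketClient {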
      private boolean pushLock;
      private Gson gson;
      private Queue<CacheContainer> queue = new ConcurrentLinkedQueue<>();
    
      WebSocketClient(MyQuery query, CacheHandler cacheHandler) {
        pushLock = true;
        this.gson = GsonFactory.getGson(query, cacheHandler);
      }
    
      public synchronized boolean isPushLock() {
        return pushLock;
      }
    
      public synchronized void pushUnlock() {
        pushLock = false;
      }
    
      public Gson getGson() {
        return gson;
      }
    
      public Queue<CacheContainer> getQueue() {
        return queue;
      }
    
      public boolean hasBackLog() {
        return !queue.isEmpty();
      }
    }
    
    public class MyConnectionCallback implements WebSocketConnectionCallback {
    
      private final Map<WebSocketChannel, WebSocketClient> clients = new ConcurrentHashMap<>();
      private final BlockingQueue<CacheContainer> messageQueue = new LinkedBlockingQueue<>();
    
      private final Gson queryGson = new GsonBuilder().disableHtmlEscaping().create();
    
      private final CacheHandler cacheHandler;
    
      MyConnectionCallback(CacheHandler cacheHandler) {
        this.cacheHandler = cacheHandler;
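        Thread pusherThread = new Thread(() -> {
          boolean hasPushLock = false;
          try {
            while (true) {
              // While some client is still push-locked, keep polling so its backlog is flushed once it unlocks.
              if (messageQueue.isEmpty() && hasPushLock) hasPushLock = pushToAllClients(null);
              else hasPushLock = pushToAllClients(messageQueue.take());
            }
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // stop pushing when interrupted
          }
        }, "PusherThread");
        pusherThread.start();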
      }
    
      @Override
      public void onConnect(WebSocketHttpExchange webSocketHttpExchange, WebSocketChannel webSocketChannel) {
        webSocketChannel.getReceiveSetter().set(new AbstractReceiveListener() {
          @Override
          protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) throws IOException {
            MyQuery query = new MyQuery(queryGson.fromJson(message.getData(), QueryJson.class));
            WebSocketClient clientConfig = new WebSocketClient(query, cacheHandler);
            clients.put(webSocketChannel, clientConfig);
            push(webSocketChannel, clientConfig.getGson(), cacheHandler.getCache());
            clientConfig.pushUnlock();
          }
        });
        webSocketChannel.resumeReceives();
      }
    
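      void putMessage(CacheContainer message) throws InterruptedException {
        messageQueue.put(message); // put() declares InterruptedException even though this queue is unbounded
      }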
    
      private synchronized void push(WebSocketChannel webSocketChannel, Gson gson, CacheContainer message) throws IOException {
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
          JsonWriter jsonWriter = new JsonWriter(new OutputStreamWriter(baos, StandardCharsets.UTF_8))) {
          gson.toJson(message, CacheContainer.class, jsonWriter);
          jsonWriter.flush();
          if (baos.size() > 2) {
            WebSockets.sendText(ByteBuffer.wrap(baos.toByteArray()), webSocketChannel, null);
          }
        }
      }
    
      private synchronized boolean pushToAllClients(CacheContainer message) {
        AtomicBoolean hadPushLock = new AtomicBoolean(false);
        Set<WebSocketChannel> closed = new HashSet<>();
    
        clients.forEach((webSocketChannel, clientConfig) -> {
          if (webSocketChannel.isOpen()) {
            if (clientConfig.isPushLock()) {
              hadPushLock.set(true);
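              // A null message is only a "flush backlogs" signal; ConcurrentLinkedQueue rejects null elements.
              if (message != null) clientConfig.getQueue().add(message);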
            } else {
              try {
                if (clientConfig.hasBackLog())
                  pushBackLog(webSocketChannel, clientConfig);
                if (message != null)
                  push(webSocketChannel, clientConfig.getGson(), message);
              } catch (Exception e) {
                closeChannel(webSocketChannel, closed);
              }
            }
          } else {
            closed.add(webSocketChannel);
          }
        });
    
        closed.forEach(clients::remove);
        return hadPushLock.get();
      }
    
      private void pushBackLog(WebSocketChannel webSocketChannel, WebSocketClient clientConfig) throws IOException {
        while (clientConfig.hasBackLog()) {
          push(webSocketChannel, clientConfig.getGson(), clientConfig.getQueue().poll());
        }
      }
    
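      private void closeChannel(WebSocketChannel channel, Set<WebSocketChannel> closed) {
        closed.add(channel);
        // close() declares IOException, which the forEach lambda cannot propagate; safeClose swallows it.
        org.xnio.IoUtils.safeClose(channel);
      }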
    }
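
To put a rough number on the encoding remark in the solution, the sketch below compares the UTF-8 size of a small JSON document with its size at two bytes per char. It is only an illustration under assumptions: EncodingSizeDemo is a hypothetical name, and the two-bytes-per-char figure describes a UTF-16 backed String. Newer JVMs may store Latin-1-only strings more compactly, in which case the main saving would come from not holding an extra String copy of the payload alongside the bytes that are actually sent.

```java
import java.nio.charset.StandardCharsets;

class EncodingSizeDemo {
    /** Bytes needed for the text as UTF-8, the encoding used by WebSocket text frames. */
    static int utf8Bytes(String text) {
        return text.getBytes(StandardCharsets.UTF_8).length;
    }

    /** Bytes needed at two bytes per char, as in a UTF-16 backed String. */
    static int utf16Bytes(String text) {
        return text.length() * 2;
    }

    public static void main(String[] args) {
        String json = "{\"key\":\"value\"}";
        System.out.println(utf8Bytes(json) + " vs " + utf16Bytes(json)); // prints 15 vs 30
    }
}
```

For mostly-ASCII JSON the UTF-8 form is about half the size of the two-bytes-per-char form; payloads with many non-ASCII characters would narrow that gap.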