Search code examples
Tags: tomcat, servlets, websocket, proxy

Proxy regular HTTP and WebSocket via Tomcat servlet


I'm implementing a web application which, among other things, has to show and interact with webpages proxied to backend services. For this, I'm using the HTTP-Proxy-Servlet which works well most of the time.

However, certain backend services' webpages use websockets and the proxy servlet above doesn't support websockets.

I tried implementing it by reconstructing the websocket call towards the backend and then copying between streams, but that doesn't work. The browser reports "Invalid frame header" and Tomcat fails with

Error parsing HTTP request header
Invalid character found in method name. HTTP method names must be tokens
at org.apache.coyote.http11.Http11InputBuffer.parseRequestLine(Http11InputBuffer.java:414)

My code:

import java.io.IOException;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.*;

import javax.servlet.ServletException;
import javax.servlet.http.*;

import org.apache.http.HttpRequest;
import org.mitre.dsmiley.httpproxy.ProxyServlet;

/**
 * {@link ProxyServlet} subclass that additionally relays WebSocket handshakes.
 *
 * <p>Requests with a {@code Sec-WebSocket-Key} header are replayed to the backend over a raw
 * {@link Socket}; all other requests are handled by {@link ProxyServlet#service}.
 *
 * <p>NOTE(review): this is the version that fails as described in the question. The client
 * connection is never switched over with {@code HttpServletRequest.upgrade(...)}; presumably
 * that is why Tomcat later parses WebSocket frame bytes as an HTTP request line ("Invalid
 * character found in method name").
 */
public class ProxyWithWebSocket extends ProxyServlet {

    private static final long serialVersionUID = -2566573965489129976L;

    /** Runs the two byte-copying tasks started for each WebSocket request. */
    protected ExecutorService exec;
    
    /** Initializes ProxyServlet, then creates the relay thread pool. */
    @Override
    public void init() throws ServletException {
        super.init();
        exec = Executors.newCachedThreadPool();
    }
    
    /**
     * Destroys ProxyServlet, then stops the pool from accepting new tasks
     * (already-running tasks are not interrupted by shutdown()).
     */
    @Override
    public void destroy() {
        super.destroy();
        exec.shutdown();
    }

    /**
     * Relays a WebSocket handshake to the backend, or delegates to {@link ProxyServlet}.
     *
     * <p>For a handshake: rebuilds the GET request with the client's headers, sends it over a new
     * backend socket, copies the backend's status line and headers onto the servlet response and
     * flushes it, then copies bytes in both directions on two pooled threads. The backend socket
     * is closed when this method returns.
     */
    @Override
    protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
            throws ServletException, IOException {
        // The presence of Sec-WebSocket-Key is used as the "this is a WebSocket handshake" test.
        var wsKey = servletRequest.getHeader("Sec-WebSocket-Key");
        if (wsKey != null) {
            //initialize request attributes from caches if unset by a subclass by this point
            if (servletRequest.getAttribute(ATTR_TARGET_URI) == null) {
              servletRequest.setAttribute(ATTR_TARGET_URI, targetUri);
            }
            if (servletRequest.getAttribute(ATTR_TARGET_HOST) == null) {
              servletRequest.setAttribute(ATTR_TARGET_HOST, targetHost);
            }
            String proxyRequestUri = rewriteUrlFromRequest(servletRequest);
            URL u = new URL(proxyRequestUri);

            // Ordinary request/response streams: the connection is NOT upgraded in this version.
            var servletIn = servletRequest.getInputStream();
            var servletOut = servletResponse.getOutputStream();

            // NOTE(review): URL.getPort() is -1 when the target URI has no explicit port,
            // which makes this Socket constructor throw.
            try (Socket sock = new Socket(u.getHost(), u.getPort())) {
                var sockIn = sock.getInputStream();
                var sockOut = sock.getOutputStream();
                
                // Serialize the handshake as a raw HTTP/1.1 request.
                // getHeader() returns only the first value of a repeated header.
                StringBuilder req = new StringBuilder(512);
                req.append("GET " + u.getFile()).append(" HTTP/1.1");
                System.out.println("  > WS|" + req);
                req.append("\r\n");
                var en = servletRequest.getHeaderNames();
                while (en.hasMoreElements()) {
                    var n = en.nextElement();
                    String header = servletRequest.getHeader(n);
                    System.out.println("  > WS| " + n + ": " + header);
                    req.append(n + ": " + header + "\r\n");
                }
                req.append("\r\n");
                
                sockOut.write(req.toString().getBytes(StandardCharsets.UTF_8));
                sockOut.flush();
    
                // Read the backend's response head one byte at a time until CRLF CRLF, so that
                // no byte after the blank line is consumed here. Stops silently on EOF.
                StringBuilder responseBytes = new StringBuilder(512);
                int b = 0;
                while (b != -1) {
                    b = sockIn.read();
                    if (b != -1) {
                        responseBytes.append((char)b);
                        var len = responseBytes.length();
                        if (len >= 4
                                && responseBytes.charAt(len - 4) == '\r'
                                && responseBytes.charAt(len - 3) == '\n'
                                && responseBytes.charAt(len - 2) == '\r'
                                && responseBytes.charAt(len - 1) == '\n'
                        ) {
                            break;
                        }
                    }
                }
                
                String[] rows = responseBytes.toString().split("\r\n"); 
                
                // rows[0] is the status line, e.g. "HTTP/1.1 101 Switching Protocols".
                String response = rows[0];
                System.out.println("  < WS|" + response);
                
                // The status code is taken from between the first and second space, so this
                // assumes the status line has a reason phrase.
                int idx1 = response.indexOf(' ');
                int idx2 = response.indexOf(' ', idx1 + 1);
                
                // Copy each "Name: value" header onto the servlet response. Assumes every line has
                // a ':' followed by exactly one character before the value.
                for (int i = 1; i < rows.length; i++) {
                    String line = rows[i];
                    int idx3 = line.indexOf(":");
                    var k = line.substring(0, idx3);
                    var headerField = line.substring(idx3 + 2);
                    System.out.println("  < WS| " + k + ": " + headerField);
                    servletResponse.setHeader(k, headerField);
                }
                
                // NOTE(review): commits the backend's status (101) as an ordinary response; nothing
                // here tells the container that the connection switches protocols.
                servletResponse.setStatus(Integer.parseInt(response.substring(idx1 + 1, idx2)));
                servletResponse.flushBuffer();
                
                System.out.println("  < WS| Flush");
    
                // Client -> backend, one byte per read. NOTE(review): the log in the question shows
                // this side finishing with 0 bytes, presumably because the request input stream of a
                // non-upgraded request has no body to read.
                var f1 = exec.submit(() -> {
                    var c = 0;
                    
                    var bs = 0;
                    while ((bs = servletIn.read()) != -1) {
                        sockOut.write(bs);
                        c++;
                    }
                    System.out.println("  > WS| Done: " + c);
                    return null;
                });
                // Backend -> client, flushing after every byte.
                var f2 = exec.submit(() -> {
                    var c = 0;
                    
                    var bs = 0;
                    while ((bs = sockIn.read()) != -1) {
                        servletOut.write(bs);
                        servletOut.flush();
                        c++;
                    }
                    System.out.println("  < WS| Done: " + c);
                    return null;
                });
    
                // Wait for client -> backend; if it failed, cancel the other direction and return
                // (the try-with-resources then closes the backend socket).
                try {
                    f1.get();
                } catch (Exception ex) {
                    f2.cancel(true);
                    return;
                }
                // Then wait for backend -> client.
                try {
                    f2.get();
                } catch (Exception ex) {
                    // failure of the backend -> client copy is ignored
                }
            }
        } else {
            super.service(servletRequest, servletResponse);
        }
    }
}

A typical exchange looks like this (as logged by the println calls above):

  > WS|GET /cellhub?id=NhWO8SnGyDb_Vrk23rmhVQ HTTP/1.1
  > WS| host: localhost:8080
  > WS| connection: Upgrade
  > WS| pragma: no-cache
  > WS| cache-control: no-cache
  > WS| user-agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.71 Safari/537.36
  > WS| upgrade: websocket
  > WS| origin: http://localhost:8080
  > WS| sec-websocket-version: 13
  > WS| accept-encoding: gzip, deflate, br
  > WS| accept-language: hu,hu-HU;q=0.9,en-US;q=0.8,en;q=0.7
  > WS| cookie: JSESSIONID=57E4B30452BC3EB2657139DAF70E65AD; JSESSIONID=AD5E7BB5FE17B4072F3ABEE32B9479AC
  > WS| sec-websocket-key: nrZWEb6Co4DKggUNwPeV8g==
  > WS| sec-websocket-extensions: permessage-deflate; client_max_window_bits
  < WS|HTTP/1.1 101 Switching Protocols
  < WS| Connection:  Upgrade
  < WS| Date:  Thu, 07 Oct 2021 13:18:41 GMT
  < WS| Server:  Kestrel
  < WS| Upgrade:  websocket
  < WS| Sec-WebSocket-Accept:  /9uN8ZF67WepGJQ3+DPBLMCBotc=
  < WS| Flush
  > WS| Done: 0
  < WS| Done: 42

How can I make this work?

Edit

I found the HttpServletRequest.upgrade method which appears to be for changing protocols. I've updated the part after the header copying:

                // Numeric status code from the backend's status line
                // (text between its first and second space).
                int respCode = Integer.parseInt(response.substring(idx1 + 1, idx2));
                if (respCode != 101) {
                    // Backend did not switch protocols: pass its status back as a plain HTTP
                    // response and mark the backend socket to be closed.
                    servletResponse.setStatus(respCode);
                    servletResponse.flushBuffer();
                    System.out.println("  < WS| Flush");
                    closeSocket = true;
                } else {
                    // Backend switched protocols: upgrade the client connection and give the
                    // handler the backend socket and its streams. The container calls the
                    // handler's init() only after service() returns, so preInit() runs first.
                    var uh = servletRequest.upgrade(WsUpgradeHandler.class);
                    uh.preInit(exec, sockIn, sockOut, sock);
                }

Where WsUpgradeHandler is

    /**
     * Upgrade handler that relays raw bytes between the upgraded client connection and an
     * already-connected backend socket.
     *
     * <p>Instantiated by the container via {@code HttpServletRequest.upgrade(WsUpgradeHandler.class)};
     * the servlet then supplies the backend socket through {@link #preInit}.
     *
     * <p>NOTE(review): {@link #init} returns right after submitting the two relay tasks. According
     * to the accepted solution in this post, init() must not return while the relay is running
     * when the connection's streams are used in blocking mode.
     */
    public static class WsUpgradeHandler implements HttpUpgradeHandler {

        // Supplied by preInit() before the container calls init().
        ExecutorService exec;
        InputStream sockIn;
        OutputStream sockOut;
        Socket sock;
        // Relay tasks. Both stay null if init() fails before submitting them; destroy() does not
        // check for that and would then throw NullPointerException.
        Future<?> f1;
        Future<?> f2;
        
        public WsUpgradeHandler() { }
        
        /** Hands over the relay executor, the connected backend socket and its streams. */
        public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
            this.exec = exec;
            this.sockIn = sockIn;
            this.sockOut = sockOut;
            this.sock = sock;
        }
        
        /** Starts both relay directions on {@link #exec} and returns immediately. */
        @Override
        public void init(WebConnection wc) {
            System.out.println("  * WS| Upgrade begin");
            try {
                var servletIn = wc.getInputStream();
                var servletOut = wc.getOutputStream();
                // Client -> backend, one byte per read; closes the backend output stream at the end.
                f1 = exec.submit(() -> {
                    System.out.println("  > WS| Client -> Backend");
                    var c = 0;
                    
                    var bs = 0;
                    try {
                        while ((bs = servletIn.read()) != -1) {
                            sockOut.write(bs);
                            c++;
                        }
                    } catch (Exception exc) {
                        exc.printStackTrace();
                    } finally {
                        sockOut.close();
                    }
                    System.out.println("  > WS| Done: " + c);
                    return null;
                });
                // Backend -> client, flushing after every byte; closes the client output stream at the end.
                f2 = exec.submit(() -> {
                    System.out.println("  > WS| Backend -> Client");
                    var c = 0;
                    
                    try {
                        var bs = 0;
                        while ((bs = sockIn.read()) != -1) {
                            servletOut.write(bs);
                            servletOut.flush();
                            c++;
                        }
                    } catch (Exception exc) {
                        exc.printStackTrace();
                    } finally {
                        servletOut.close();
                    }
                    System.out.println("  < WS| Done: " + c);
                    return null;
                });

            } catch (IOException ex) {
                ex.printStackTrace();
            }
        }

        /**
         * Cancels both relay tasks and closes the backend socket. Closing the socket makes a thread
         * blocked reading from it fail with SocketException.
         */
        @Override
        public void destroy() {
            System.out.println("  * WS| Upgrade closing");
            f1.cancel(true);
            f2.cancel(true);
            try {
                sock.close();
            } catch (IOException ex) {
                // failure to close the backend socket is ignored
            }
            System.out.println("  * WS| Upgrade close");
        }
        
    }

This does work for passing messages around, but once the browser ends the WebSocket connection, Tomcat's CPU utilization goes very high, even though no other activity should be happening at that point. It appears that some or all of Tomcat's NIO threads are spinning, and the thread pool I'm using no longer has any threads.


Solution

  • I think I managed to solve the issue.

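    The code above was almost correct, with one exception: when the WebConnection streams are used in blocking mode, the upgrade handler's init() method must not return until the relay has finished, as demonstrated by a Tomcat test example (the original link was not preserved).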

    The second issue, the high CPU usage, was tracked down to a bug in Tomcat's poller thread that has since been fixed. I was running my code on Tomcat 9.0.12; after upgrading to Tomcat 9.0.54, the CPU usage problem went away.

    The complete working code looks like this. (I know, I know: shoveling bytes and manually building HTTP requests is not optimal, but that's what Loom is for, right? ;)

    import java.io.*;
    import java.net.*;
    import java.nio.charset.StandardCharsets;
    import java.util.concurrent.*;
    
    import javax.servlet.ServletException;
    import javax.servlet.http.*;
    
    import org.apache.http.HttpRequest;
    import org.mitre.dsmiley.httpproxy.ProxyServlet;
    
    /**
     * {@link ProxyServlet} that also proxies WebSocket connections.
     *
     * <p>Ordinary requests are handled by {@link ProxyServlet}. A request carrying
     * {@code Sec-WebSocket-Key} is replayed to the backend over a raw socket; if the backend
     * switches protocols, the client connection is upgraded and {@link WsUpgradeHandler} relays
     * bytes in both directions.
     */
    public class ProxyWithWebSocket extends ProxyServlet {
    
        private static final long serialVersionUID = -2566573965489129976L;
    
        /** Name of both the request attribute and the backend header carrying the user id. */
        private static final String USER_ID = "UserID";
    
        /** Size of the buffer used when relaying WebSocket bytes in either direction. */
        private static final int RELAY_BUFFER_SIZE = 8192;
    
        /** Thread pool running the backend-to-client relay of each upgraded connection. */
        protected ExecutorService exec;
        
        @Override
        public void init() throws ServletException {
            super.init();
            exec = Executors.newCachedThreadPool();
        }
        
        @Override
        public void destroy() {
            super.destroy();
            // Stops accepting new relays; relays that are already running are not interrupted.
            exec.shutdown();
        }
        
        /**
         * Copies headers for an ordinary proxied request and sets {@code UserID} from the request
         * attribute of the same name.
         *
         * <p>Any {@code UserID} header sent by the client is dropped first, so the backend only sees
         * the value this proxy derived from the attribute. NOTE(review): assumes the backend trusts
         * this header for identity — confirm.
         */
        @Override
        protected void copyRequestHeaders(HttpServletRequest servletRequest, HttpRequest proxyRequest) {
            super.copyRequestHeaders(servletRequest, proxyRequest);
            proxyRequest.removeHeaders(USER_ID);
            String userId = (String) servletRequest.getAttribute(USER_ID);
            if (userId != null) {
                proxyRequest.addHeader(USER_ID, userId);
            }
        }
    
        /**
         * Proxies WebSocket handshakes over a raw backend socket; delegates everything else to
         * {@link ProxyServlet}.
         *
         * <p>If the backend answers {@code 101 Switching Protocols}, the client connection is
         * upgraded and {@link WsUpgradeHandler} takes ownership of the backend socket. Any other
         * status is returned to the client and the backend socket is closed.
         *
         * <p>NOTE: for a non-101 reply only the status line and headers are relayed, not the body.
         *
         * @throws IOException if the backend cannot be reached or sends a malformed response head
         */
        @Override
        protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
                throws ServletException, IOException {
            if (servletRequest.getHeader("Sec-WebSocket-Key") == null) {
                super.service(servletRequest, servletResponse);
                return;
            }
    
            //initialize request attributes from caches if unset by a subclass by this point
            if (servletRequest.getAttribute(ATTR_TARGET_URI) == null) {
                servletRequest.setAttribute(ATTR_TARGET_URI, targetUri);
            }
            if (servletRequest.getAttribute(ATTR_TARGET_HOST) == null) {
                servletRequest.setAttribute(ATTR_TARGET_HOST, targetHost);
            }
            URL u = new URL(rewriteUrlFromRequest(servletRequest));
    
            Socket sock = openBackendSocket(u);
            // The socket is handed to WsUpgradeHandler only after a successful upgrade(); on every
            // other path (non-101 reply, I/O error, malformed response) it must be closed here.
            boolean handedOff = false;
            try {
                InputStream sockIn = sock.getInputStream();
                OutputStream sockOut = sock.getOutputStream();
    
                sockOut.write(buildHandshakeRequest(servletRequest, u));
                sockOut.flush();
    
                String[] rows = readResponseHead(sockIn).split("\r\n");
                String statusLine = rows[0];
                System.out.println("  < WS|" + statusLine);
                // Parse before touching the response so a malformed reply leaves it untouched.
                int respCode = parseStatusCode(statusLine);
    
                for (int i = 1; i < rows.length; i++) {
                    copyResponseHeader(rows[i], servletResponse);
                }
    
                if (respCode != HttpServletResponse.SC_SWITCHING_PROTOCOLS) {
                    servletResponse.setStatus(respCode);
                    servletResponse.flushBuffer();
                    System.out.println("  < WS| Flush");
                } else {
                    // No explicit setStatus(101): Tomcat's upgrade() sets it. The container calls
                    // the handler's init() only after service() returns, so preInit() runs first.
                    var uh = servletRequest.upgrade(WsUpgradeHandler.class);
                    uh.preInit(exec, sockIn, sockOut, sock);
                    handedOff = true;
                }
            } finally {
                if (!handedOff) {
                    sock.close();
                }
            }
        }
    
        /**
         * Connects to the backend named by {@code u}: plain TCP for http, TLS with hostname
         * verification for https. The port defaults from the scheme when {@code u} has none.
         */
        private static Socket openBackendSocket(URL u) throws IOException {
            // URL.getPort() is -1 when the target URI has no explicit port.
            int port = u.getPort() != -1 ? u.getPort() : u.getDefaultPort();
            if (!"https".equalsIgnoreCase(u.getProtocol())) {
                return new Socket(u.getHost(), port);
            }
            var tls = (javax.net.ssl.SSLSocket) javax.net.ssl.SSLSocketFactory.getDefault()
                    .createSocket(u.getHost(), port);
            // Without this the default factory validates the certificate chain but not that the
            // certificate was issued for this host. The handshake runs lazily on first I/O.
            var params = tls.getSSLParameters();
            params.setEndpointIdentificationAlgorithm("HTTPS");
            tls.setSSLParameters(params);
            return tls;
        }
    
        /**
         * Serializes the client's handshake as an HTTP/1.1 GET for the backend.
         *
         * <p>Every client header value is forwarded, including repeated headers, except a
         * client-sent {@code UserID}, which is replaced by the {@code UserID} request attribute just
         * as {@link #copyRequestHeaders} does for ordinary requests.
         */
        private static byte[] buildHandshakeRequest(HttpServletRequest servletRequest, URL u) {
            // getFile() is path plus query; an empty path is not a valid request target.
            String target = u.getFile().isEmpty() ? "/" : u.getFile();
            StringBuilder req = new StringBuilder(512);
            req.append("GET ").append(target).append(" HTTP/1.1");
            System.out.println("  > WS|" + req);
            req.append("\r\n");
            var names = servletRequest.getHeaderNames();
            while (names.hasMoreElements()) {
                String name = names.nextElement();
                if (USER_ID.equalsIgnoreCase(name)) {
                    continue;
                }
                var values = servletRequest.getHeaders(name);
                while (values.hasMoreElements()) {
                    appendHeader(req, name, values.nextElement());
                }
            }
            String userId = (String) servletRequest.getAttribute(USER_ID);
            if (userId != null) {
                appendHeader(req, USER_ID, userId);
            }
            req.append("\r\n");
            // Tomcat decodes header bytes as ISO-8859-1; encoding them back the same way forwards
            // the original bytes unchanged (UTF-8 would mangle non-ASCII octets).
            return req.toString().getBytes(StandardCharsets.ISO_8859_1);
        }
    
        /** Appends one {@code name: value} header line (with CRLF) to {@code req} and logs it. */
        private static void appendHeader(StringBuilder req, String name, String value) {
            System.out.println("  > WS| " + name + ": " + value);
            req.append(name).append(": ").append(value).append("\r\n");
        }
    
        /**
         * Reads the backend's status line and headers, up to and including the blank line.
         *
         * <p>Reads one byte at a time on purpose: anything after the blank line already belongs to
         * the WebSocket stream and must stay in {@code sockIn} for the relay.
         *
         * @throws EOFException if the backend closes the connection before the headers end
         */
        private static String readResponseHead(InputStream sockIn) throws IOException {
            StringBuilder head = new StringBuilder(512);
            int b;
            while ((b = sockIn.read()) != -1) {
                // (char) of an unsigned byte value is an ISO-8859-1 decode.
                head.append((char) b);
                int len = head.length();
                if (len >= 4
                        && head.charAt(len - 4) == '\r'
                        && head.charAt(len - 3) == '\n'
                        && head.charAt(len - 2) == '\r'
                        && head.charAt(len - 1) == '\n') {
                    return head.toString();
                }
            }
            throw new EOFException("Backend closed the connection before the end of the response headers");
        }
    
        /**
         * Extracts the numeric code from a status line such as {@code HTTP/1.1 101 Switching Protocols}.
         *
         * @throws IOException if the status line is malformed
         */
        private static int parseStatusCode(String statusLine) throws IOException {
            // The reason phrase is optional, so do not rely on a second space.
            String[] parts = statusLine.split(" ", 3);
            if (parts.length < 2) {
                throw new IOException("Malformed status line from backend: " + statusLine);
            }
            try {
                return Integer.parseInt(parts[1]);
            } catch (NumberFormatException e) {
                throw new IOException("Malformed status line from backend: " + statusLine, e);
            }
        }
    
        /** Copies one {@code name: value} header line onto the response; lines without a name are skipped. */
        private static void copyResponseHeader(String line, HttpServletResponse servletResponse) {
            int colon = line.indexOf(':');
            if (colon <= 0) {
                return;
            }
            String name = line.substring(0, colon).trim();
            // Whitespace around a header value is not part of the value.
            String value = line.substring(colon + 1).trim();
            System.out.println("  < WS| " + name + ": " + value);
            // addHeader keeps repeated headers such as Set-Cookie.
            servletResponse.addHeader(name, value);
        }
        
        /**
         * Relays raw bytes between an upgraded client connection and the backend socket.
         *
         * <p>Created by the container through {@code HttpServletRequest.upgrade(...)}; the servlet
         * then calls {@link #preInit} before the container calls {@link #init}.
         */
        public static class WsUpgradeHandler implements HttpUpgradeHandler {
    
            ExecutorService exec;
            InputStream sockIn;
            OutputStream sockOut;
            Socket sock;
            /** Backend-to-client relay task; null until {@link #init} submits it. */
            Future<?> f2;
            
            public WsUpgradeHandler() { }
            
            /** Hands over the relay executor and the connected backend socket with its streams. */
            public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
                this.exec = exec;
                this.sockIn = sockIn;
                this.sockOut = sockOut;
                this.sock = sock;
            }
            
            /**
             * Relays until both directions have ended: backend-to-client on {@link #exec},
             * client-to-backend on the calling container thread. Deliberately does not return early,
             * because the {@link WebConnection} streams are used in blocking mode.
             */
            @Override
            public void init(WebConnection wc) {
                System.out.println("  * WS| Upgrade begin");
                try {
                    var servletIn = wc.getInputStream();
                    var servletOut = wc.getOutputStream();
                    f2 = exec.submit(() -> {
                        System.out.println("  > WS| Backend -> Client");
                        var chunk = new byte[RELAY_BUFFER_SIZE];
                        long relayed = 0;
                        try {
                            int n;
                            while ((n = sockIn.read(chunk)) != -1) {
                                servletOut.write(chunk, 0, n);
                                // Flush per chunk so WebSocket frames are not held back.
                                servletOut.flush();
                                relayed += n;
                            }
                        } catch (SocketException | EOFException exc) {
                            // this is fine: the connection was closed from one side
                        } catch (Exception exc) {
                            exc.printStackTrace();
                        } finally {
                            servletOut.close();
                        }
                        System.out.println("  < WS| Done: " + relayed);
                        return null;
                    });
    
                    System.out.println("  > WS| Client -> Backend");
                    var buf = new byte[RELAY_BUFFER_SIZE];
                    long c = 0;
                    try {
                        int read;
                        while ((read = servletIn.read(buf)) != -1) {
                            sockOut.write(buf, 0, read);
                            sockOut.flush();
                            c += read;
                        }
                    } catch (SocketException | EOFException exc) {
                        // this is fine: the connection was closed from one side
                    } catch (Exception exc) {
                        exc.printStackTrace();
                    } finally {
                        // Closing a socket's output stream closes the socket, which also ends the
                        // backend-to-client read above.
                        sockOut.close();
                    }
                    System.out.println("  > WS| Done: " + c);
    
                    f2.get();
                } catch (InterruptedException ex) {
                    Thread.currentThread().interrupt();
                } catch (Exception ex) {
                    ex.printStackTrace();
                } finally {
                    if (f2 != null) {
                        f2.cancel(true);
                    }
                }
            }
    
            /** Cancels the backend-to-client relay and closes the backend socket. */
            @Override
            public void destroy() {
                System.out.println("  * WS| Upgrade closing");
                if (f2 != null) {
                    f2.cancel(true);
                }
                try {
                    sock.close();
                } catch (IOException ex) {
                    // best effort: the socket may already be closed
                }
                System.out.println("  * WS| Upgrade close");
            }
            
        }
    }