Tags: haskell, tcp, haskell-snap-framework

In Haskell, how can I abort a calculation when a web client disconnects


I have a Haskell-based web service that performs a calculation that, for some inputs, can take a very long time to finish ("very long" here means over a minute).

Because the calculation uses all of the CPU available on the server, I put incoming requests in a queue when they arrive (well, actually a stack, for reasons that have to do with the typical client, but that's beside the point) and service the next one when the currently running calculation finishes.

My problem is that the clients don't always wait long enough: sometimes they time out on their end, disconnect, and try a different server (well, they retry, hit the ELB, and usually get a different instance). Also, occasionally the calculation a web client asked for becomes obsolete because of external factors, and the web client is killed.

In those cases I'd really like to detect that the web client has gone away before I pull the next request off the stack and start the (expensive) calculation. Unfortunately, my experience with Snap leads me to believe that the framework offers no way to ask "is the client's TCP connection still connected?", and I haven't found documentation for any other web framework that covers the "client disconnected" case.
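Stripped of the web framework, the behaviour being asked for looks roughly like the sketch below: a LIFO stack of pending jobs and a single worker that discards jobs whose client has already gone before it starts the expensive work. Everything here is hypothetical (`Job`, `jobAlive`, `pushJob`, `popLiveJob`, `runWorker`); in particular, setting the `jobAlive` flag when the TCP connection drops is precisely the part the framework does not provide, so the sketch just flips it by hand.

```haskell
import Control.Concurrent.MVar
import Control.Monad (replicateM)
import Data.IORef

-- A hypothetical pending request: the client's input plus a flag that some
-- other part of the server (not shown) would clear on disconnect.
data Job = Job
  { jobInput :: Int
  , jobAlive :: IORef Bool
  }

-- Push onto the front of the list, so the newest request is served first (LIFO).
pushJob :: MVar [Job] -> Job -> IO ()
pushJob stack job = modifyMVar_ stack (return . (job :))

-- Pop jobs until one whose client is still connected is found,
-- silently discarding the ones whose clients have gone away.
popLiveJob :: MVar [Job] -> IO (Maybe Job)
popLiveJob stack = do
  next <- modifyMVar stack $ \js -> case js of
    []        -> return ([], Nothing)
    (j : js') -> return (js', Just j)
  case next of
    Nothing -> return Nothing
    Just j  -> do
      alive <- readIORef (jobAlive j)
      if alive then return (Just j) else popLiveJob stack

-- Stand-in for the expensive calculation.
expensive :: Int -> Int
expensive n = sum [1 .. n]

-- Drain the stack one job at a time, returning (input, result) pairs
-- for the jobs that were actually computed.
runWorker :: MVar [Job] -> IO [(Int, Int)]
runWorker stack = do
  mj <- popLiveJob stack
  case mj of
    Nothing -> return []
    Just j  -> do
      let r = expensive (jobInput j)
      rest <- runWorker stack
      return ((jobInput j, r) : rest)

main :: IO ()
main = do
  stack <- newMVar []
  flags <- replicateM 3 (newIORef True)
  mapM_ (pushJob stack) (zipWith Job [10, 20, 30] flags)
  -- Simulate the client that submitted input 20 disconnecting while queued.
  writeIORef (flags !! 1) False
  runWorker stack >>= print
```

Running `main` serves input 30 first (it was pushed last), skips 20 because its client "disconnected", and then serves 10.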

So is there a Haskell web framework that makes it easy to detect whether a web client has disconnected? Or failing that, is there one that at least makes it possible?

(I understand that it may not be possible to be absolutely certain in all cases whether a TCP client is still there without sending data to the other end. However, when the client actually sends RST packets to the server and the server's framework doesn't let the application code find out that the connection is gone, that's a problem.)


Incidentally, although one might expect Warp's onClose handler to allow this, it fires only once a response is ready and has been written to the client, so it is useless for aborting a calculation in progress. There also seems to be no way to get at the accepted socket in order to set SO_KEEPALIVE or similar. (There are ways to access the initial listening socket, but not the accepted one.)
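For comparison, once you do hold the accepted `Socket`, enabling keepalive is a single call in the `network` package. A minimal sketch, assuming network 3.x; `enableKeepAlive` is a hypothetical helper, and the socket here is a fresh, unconnected one used only to show the option being set and read back:

```haskell
import Network.Socket

-- Ask the kernel to send keepalive probes on this socket, so that a peer
-- that vanished without sending FIN or RST is eventually detected.
enableKeepAlive :: Socket -> IO ()
enableKeepAlive sock = setSocketOption sock KeepAlive 1

main :: IO ()
main = do
  sock <- socket AF_INET Stream defaultProtocol
  enableKeepAlive sock
  -- Some platforms report the enabled option as a non-1 flag value, so
  -- only test for non-zero.
  v <- getSocketOption sock KeepAlive
  putStrLn ("SO_KEEPALIVE is " ++ (if v /= 0 then "on" else "off"))
  close sock
```

Keepalive only makes the kernel probe idle connections (by default after two hours of idleness on Linux), and the application still learns of a dead peer only when it next reads from or writes to the socket, so on its own it would not interrupt a running calculation.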


Solution

  • So I found an answer that works for me, and it might work for someone else.

    It turns out that you can in fact mess around enough with Warp's internals to do this, but what you're left with is a bare-bones version of Warp: if you need things like logging, you will have to add other packages on top of it.

    Also, note that so-called "half-closed" connections (where the client closes its sending end but is still waiting for data) will be detected as closed, interrupting your calculation. I don't know of any HTTP clients that use half-closed connections, but it's something to be aware of.
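The half-close caveat can be reproduced with the `network` package alone. A minimal sketch, assuming network 3.x on a Unix-like system; for convenience it uses a Unix-domain socket pair rather than TCP, but shutting down the sending side behaves the same way: the server reads end-of-file (the signal a disconnect monitor keys on) even though the half-closed client can still receive.

```haskell
{-# LANGUAGE OverloadedStrings #-}
import qualified Data.ByteString as B
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)

main :: IO ()
main = do
  -- A connected pair of stream sockets stands in for client and server.
  (client, server) <- socketPair AF_UNIX Stream defaultProtocol
  -- The client closes only its sending side (a "half-close").
  shutdown client ShutdownSend
  -- The server now reads an empty chunk, i.e. end-of-file...
  fromClient <- recv server 4096
  putStrLn ("server read EOF: " ++ show (B.null fromClient))
  -- ...yet the client can still receive a response.
  sendAll server "result"
  reply <- recv client 4096
  print reply
  close client
  close server
```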

    Anyway, what I did was first copy the functions runSettings and runSettingsSocket, exposed by Network.Wai.Handler.Warp and Network.Wai.Handler.Warp.Internal, and make versions that call a function I supply instead of WarpI.socketConnection, giving this signature:

    runSettings' :: Warp.Settings -> (Socket -> IO (IO WarpI.Connection))
                 -> Wai.Application -> IO ()
    

    This required copying out a few helper functions, like setSocketCloseOnExec and windowsThreadBlockHack. The double-IO signature might look weird, but it's what you want: the outer IO runs in the main thread (the one that calls accept), and the inner IO runs in the per-connection thread that is forked after accept returns. The original Warp function runSettings is equivalent to:

    \set -> runSettings' set (WarpI.socketConnection >=> return . return)
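
The threading contract of the double `IO` can be modelled in isolation. The toy below is not Warp's accept loop (`acceptLoop` and `runDemo` are hypothetical, and the "connections" are just strings); it only checks that the outer action runs on the accepting thread while the inner action runs on the forked per-connection thread. That placement matters because `myThreadId` inside the inner action identifies the thread actually serving the request, which is the thread an asynchronous `throwTo` has to target.

```haskell
import Control.Concurrent
import Control.Monad (forM_, replicateM_)
import Data.List (sort)

-- Same shape as runSettings': for each "connection", run the outer action
-- of the supplied IO (IO ()) on this (accepting) thread, then fork a
-- thread that runs the inner action it returned.
acceptLoop :: [c] -> (c -> IO (IO ())) -> IO ()
acceptLoop conns setup = do
  done <- newEmptyMVar
  forM_ conns $ \c -> do
    perConn <- setup c                     -- outer IO: accepting thread
    forkIO (perConn >> putMVar done ())   -- inner IO: per-connection thread
  -- Wait for every per-connection thread to finish.
  replicateM_ (length conns) (takeMVar done)

-- For each connection, record whether the outer action ran on the
-- accepting thread and whether the inner action ran on another thread.
runDemo :: IO [(String, Bool, Bool)]
runDemo = do
  acceptTid <- myThreadId
  results <- newMVar []
  acceptLoop ["a", "b"] $ \name -> do
    outerTid <- myThreadId
    return $ do
      innerTid <- myThreadId
      modifyMVar_ results $ \rs ->
        return ((name, outerTid == acceptTid, innerTid /= acceptTid) : rs)
  sort <$> readMVar results

main :: IO ()
main = runDemo >>= print
```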
    

    Then I did:

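    -- Imports assumed by the code below. It pokes at Warp's internal
    -- Connection record, so check field names against your warp version.
    import Control.Concurrent
    import Control.Exception
    import Control.Monad ((>=>))
    import Data.ByteString (ByteString)
    import qualified Data.ByteString as B
    import qualified Data.ByteString.Unsafe as BU
    import Data.IORef
    import Foreign.Marshal.Utils (copyBytes)
    import Foreign.Ptr (castPtr, plusPtr)
    import qualified Network.Wai as Wai
    import qualified Network.Wai.Handler.Warp as Warp
    import qualified Network.Wai.Handler.Warp.Internal as WarpI

    data ClientDisappeared = ClientDisappeared deriving (Show, Eq, Enum, Ord)
    instance Exception ClientDisappeared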
    
    runSettingsSignalDisconnect :: Warp.Settings -> Wai.Application -> IO ()
    runSettingsSignalDisconnect set =
      runSettings' set (WarpI.socketConnection >=> return . wrapConn)
      where
        -- Fork a 'monitor' thread that does nothing but attempt to
        -- perform a read from conn in a loop 1/sec, and wrap the receive
        -- methods on conn so that they first consume from the stuff read
        -- by the monitoring thread. If the monitoring thread sees
        -- end-of-file (signaled by an empty string read), raise
        -- ClientDisappeared on the per-connection thread.
        wrapConn conn = do
          tid <- myThreadId
          nxtBstr <- newEmptyMVar :: IO (MVar ByteString)
          semaphore <- newMVar ()
          readerCount <- newIORef (0 :: Int)
          monitorThread <- forkIO (monitor tid nxtBstr semaphore readerCount)
          return $ conn {
            WarpI.connClose = throwTo monitorThread ClientDisappeared
                              >> WarpI.connClose conn
            , WarpI.connRecv = newRecv nxtBstr semaphore readerCount
            , WarpI.connRecvBuf = newRecvBuf nxtBstr semaphore readerCount
            }
          where
            newRecv :: MVar ByteString -> MVar () -> IORef Int
                    -> IO ByteString
            newRecv nxtBstr sem readerCount =
              bracket_
              (atomicModifyIORef' readerCount $ \x -> (succ x, ()))
              (atomicModifyIORef' readerCount $ \x -> (pred x, ()))
              (withMVar sem $ \_ -> do w <- tryTakeMVar nxtBstr
                                       case w of
                                         Just w' -> return w'
                                         Nothing -> WarpI.connRecv conn
              )
    
            newRecvBuf :: MVar ByteString -> MVar () -> IORef Int
                       -> WarpI.Buffer -> WarpI.BufSize -> IO Bool
            newRecvBuf nxtBstr sem readerCount buf bufSize =
              bracket_
              (atomicModifyIORef' readerCount $ \x -> (succ x, ()))
              (atomicModifyIORef' readerCount $ \x -> (pred x, ()))
              (withMVar sem $ \_ -> do
                  (fulfilled, buf', bufSize') <-
                    if bufSize == 0 then return (False, buf, bufSize)
                    else
                      do w <- tryTakeMVar nxtBstr
                         case w of
                           Nothing -> return (False, buf, bufSize)
                           Just w' -> do
                             let wlen = B.length w'
                             if wlen > bufSize
                               then do BU.unsafeUseAsCString w' $ \cw' ->
                                         copyBytes buf (castPtr cw') bufSize
                                       putMVar nxtBstr (B.drop bufSize w')
                                       return (True, buf, 0)
                               else do BU.unsafeUseAsCString w' $ \cw' ->
                                         copyBytes buf (castPtr cw') wlen
                                       return (wlen == bufSize, plusPtr buf wlen,
                                               bufSize - wlen)
                  if fulfilled then return True
                    else WarpI.connRecvBuf conn buf' bufSize'
              )
            dropClientDisappeared :: ClientDisappeared -> IO ()
            dropClientDisappeared _ = return ()
            monitor tid nxtBstr sem st =
              catch (monitor' tid nxtBstr sem st) dropClientDisappeared
    
            monitor' tid nxtBstr sem st = do
              (hitEOF, readerCount) <- withMVar sem $ \_ -> do
                w <- tryTakeMVar nxtBstr
                case w of
                  -- No one picked up our bytestring from last time
                  Just w' -> putMVar nxtBstr w' >> return (False, 0)
                  Nothing -> do
                    w <- WarpI.connRecv conn
                    putMVar nxtBstr w
                    readerCount <- readIORef st
                    return (B.null w, readerCount)
              if hitEOF && (readerCount == 0)
                -- Don't signal if main thread is also trying to read -
                -- in that case, main thread will see EOF directly
                then throwTo tid ClientDisappeared
                else do threadDelay oneSecondInMicros
                        monitor' tid nxtBstr sem st
            oneSecondInMicros = 1000000
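
To see the moving parts without Warp, here is a self-contained simulation of the same idea: a worker thread runs a long calculation while a monitor thread polls the connection and, on an empty read (end-of-file), throws `ClientDisappeared` at the worker. It is a simplified stand-in, not the code above: the "connection" is an `MVar` of chunks (`clientSend`, `tryRecvFake`, `serveOne` and `longCalculation` are hypothetical names), the monitor polls every 5 ms rather than once a second, and any non-empty chunk it reads is simply dropped, whereas `wrapConn` stashes that data so `connRecv`/`connRecvBuf` can hand it back to the request handler.

```haskell
import Control.Concurrent
import Control.Exception
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC

data ClientDisappeared = ClientDisappeared deriving (Show, Eq)
instance Exception ClientDisappeared

-- Simulated connection: chunks the client has sent that nobody has read yet.
clientSend :: MVar [B.ByteString] -> B.ByteString -> IO ()
clientSend conn chunk = modifyMVar_ conn (return . (++ [chunk]))

tryRecvFake :: MVar [B.ByteString] -> IO (Maybe B.ByteString)
tryRecvFake conn = modifyMVar conn $ \cs -> return $ case cs of
  []        -> ([], Nothing)
  (c : cs') -> (cs', Just c)

-- Poll every pollMicros; on an empty chunk (end-of-file), raise
-- ClientDisappeared in the worker thread and stop polling.
monitor :: Int -> ThreadId -> IO (Maybe B.ByteString) -> IO ()
monitor pollMicros worker tryRecv = loop
  where
    loop = do
      next <- tryRecv
      case next of
        Just chunk | B.null chunk -> throwTo worker ClientDisappeared
        _ -> threadDelay pollMicros >> loop

-- Stand-in for the expensive calculation: many small interruptible steps.
longCalculation :: Int -> IO Int
longCalculation steps = go 0
  where
    go n | n >= steps = return n
         | otherwise  = threadDelay 100 >> go (n + 1)

-- Run one calculation while a "client" sends the given chunks 10 ms apart.
-- Returns Left ClientDisappeared if the calculation was aborted.
serveOne :: Int -> [B.ByteString] -> IO (Either ClientDisappeared Int)
serveOne steps chunks = do
  conn <- newMVar []
  resultVar <- newEmptyMVar
  -- forkFinally installs its handler before the thread can be interrupted,
  -- so the outcome always lands in resultVar.
  worker <- forkFinally (longCalculation steps) (putMVar resultVar)
  mon <- forkIO (monitor 5000 worker (tryRecvFake conn))
  mapM_ (\c -> threadDelay 10000 >> clientSend conn c) chunks
  r <- takeMVar resultVar
  killThread mon
  case r of
    Right n -> return (Right n)
    Left e  -> maybe (throwIO e) (return . Left) (fromException e)

main :: IO ()
main = do
  -- The client sends a request line, then disconnects long before the
  -- (at least ten-second) calculation could finish.
  serveOne 100000 [BC.pack "GET /", B.empty] >>= print
  -- A short calculation that finishes while the client stays connected.
  serveOne 5 [] >>= print
```

One caveat applies to the real code as well: GHC delivers asynchronous exceptions such as the one raised by `throwTo` at allocation points, so a pure calculation compiled into a tight non-allocating loop may not be interrupted promptly (GHC's `-fno-omit-yields` flag exists for that case).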