Servant: a callback that runs on any request


I use Servant. I need to register a callback that runs on every request and then passes the request on to the normal handler, as if the callback did not exist at all. Inside this callback I want to modify some IORef MyGlobalState. How can I do this? Is there an example? I could, of course, call it explicitly in every handler body, but perhaps there is a canonical way to do it.


Solution

  • Consider this example server:

    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE TypeOperators #-}
    
    import Servant
    import Servant.API
    import Network.Wai
    import Network.Wai.Handler.Warp
    import Data.IORef
    import qualified Data.ByteString.Char8 as C
    
    type API = "one" :> Get '[PlainText] String :<|> "two" :> Get '[PlainText] String
    
    api :: Proxy API
    api = Proxy
    
    server :: Server API
    server = return "1\n" :<|> return "2\n"
    
    app :: Application
    app = serve api server
    
    main :: IO ()
    main = run 3000 app
    

    The app value is a WAI Application, defined by:

    type Application = Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived
    

    Conceptually, an Application accepts a Request and passes its Response to a callback supplied by the server. WAI also supports Middleware, a function that wraps one Application in another, which allows preprocessing of every request and post-processing of every response in the IO monad:

    type Middleware = Application -> Application
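
    To make the wrapping idea concrete, here is a minimal sketch of a middleware that works on the response side: it leaves the request untouched and adds one header to every response by wrapping the respond callback. The names addDemoHeader and plainApp and the X-Demo header are hypothetical, plainApp is only a stand-in for a real application, and the sketch assumes the http-types package is available alongside wai.

    ```haskell
    {-# LANGUAGE OverloadedStrings #-}

    import Network.HTTP.Types (status200)
    import Network.Wai
      ( Application, Middleware, defaultRequest
      , mapResponseHeaders, responseHeaders, responseLBS )
    import Network.Wai.Internal (ResponseReceived (..))

    -- Hypothetical middleware: forwards the request unchanged and prepends
    -- one header to whatever response the wrapped application produces.
    addDemoHeader :: Middleware
    addDemoHeader oldapp req respond =
      oldapp req (respond . mapResponseHeaders (("X-Demo", "1") :))

    -- Hypothetical stand-in application so the sketch compiles on its own.
    plainApp :: Application
    plainApp _req respond = respond (responseLBS status200 [] "ok\n")

    -- Drive the wrapped application by hand, without starting a server:
    -- pass WAI's defaultRequest and a respond callback that prints the
    -- headers of the response it is given.
    main :: IO ()
    main = do
      _ <- addDemoHeader plainApp defaultRequest $ \resp -> do
        print (responseHeaders resp)
        return ResponseReceived
      return ()
    ```

    Because a Middleware is just a function from Application to Application, several of them can be stacked with ordinary function composition.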
    

    So, you can write a piece of middleware to update an IORef on every received request like so:

    counter :: IORef Int -> Middleware
    counter cref = convert
    
      where
        convert :: Application -> Application  -- AKA Middleware
        convert oldapp = newapp
    
          where
            newapp :: Request -> (Response -> IO ResponseReceived) -> IO ResponseReceived  -- AKA Application
            newapp req respond = do
              n <- atomicModifyIORef cref (\n' -> (n'+1,n'+1))
              putStrLn $ "Request #" ++ show n ++ ": " ++ showRequest req
              oldapp req respond
    
            showRequest req = C.unpack (requestMethod req) ++ " " ++ C.unpack (rawPathInfo req)
    

    I've broken this up into separate functions to make it clear how the middleware is constructed piece by piece, but the definition of counter can be simplified to the equivalent:

    counter :: IORef Int -> Middleware
    counter cref oldapp req respond
      = do n <- atomicModifyIORef cref (\n' -> (n'+1,n'+1))
           putStrLn $ "Request #" ++ show n ++ ": " ++ showRequest req
           oldapp req respond
    
      where
        showRequest req = C.unpack (requestMethod req) ++ " " ++ C.unpack (rawPathInfo req)
    

    Now, all you need to do is wrap your app with the middleware in main:

    main :: IO ()
    main = do
      cref <- newIORef (0 :: Int)
      run 3000 $ counter cref app
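
    The same pattern carries over to the IORef MyGlobalState from the question. The following is only a sketch under assumptions: the fields of MyGlobalState are invented for illustration, recordRequest and plainApp are hypothetical names, and plainApp stands in for the Servant app so that the example can be run without a server. It uses the strict atomicModifyIORef' so that updates are evaluated as they happen rather than accumulating as thunks.

    ```haskell
    import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef)
    import qualified Data.ByteString.Char8 as C
    import qualified Data.ByteString.Lazy.Char8 as L
    import Network.HTTP.Types (status200)
    import Network.Wai
      ( Application, Middleware, defaultRequest, rawPathInfo, responseLBS )
    import Network.Wai.Internal (ResponseReceived (..))

    -- Hypothetical stand-in for the question's MyGlobalState.
    data MyGlobalState = MyGlobalState
      { requestCount :: !Int     -- number of requests seen so far
      , lastPath     :: !String  -- raw path of the most recent request
      } deriving (Show)

    -- Update the shared state, then hand the request to the wrapped
    -- application exactly as if this middleware were not there.
    recordRequest :: IORef MyGlobalState -> Middleware
    recordRequest ref oldapp req respond = do
      atomicModifyIORef' ref $ \s ->
        ( s { requestCount = requestCount s + 1
            , lastPath     = C.unpack (rawPathInfo req) }
        , () )
      oldapp req respond

    -- Hypothetical stand-in application so the sketch compiles on its own.
    plainApp :: Application
    plainApp _req respond = respond (responseLBS status200 [] (L.pack "ok\n"))

    -- Send two requests through the wrapped application by hand and
    -- print the state that the middleware accumulated.
    main :: IO ()
    main = do
      ref <- newIORef (MyGlobalState 0 "")
      let wrapped = recordRequest ref plainApp
      _ <- wrapped defaultRequest (\_ -> return ResponseReceived)
      _ <- wrapped defaultRequest (\_ -> return ResponseReceived)
      readIORef ref >>= print
    ```

    With the Servant app from this answer, the wiring in main would be run 3000 (recordRequest ref app) in place of the hand-driven calls.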
    

    Full code:

    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE TypeOperators #-}
    
    import Servant
    import Servant.API
    import Network.Wai
    import Network.Wai.Handler.Warp
    import Data.IORef
    import qualified Data.ByteString.Char8 as C
    
    type API = "one" :> Get '[PlainText] String :<|> "two" :> Get '[PlainText] String
    
    api :: Proxy API
    api = Proxy
    
    server :: Server API
    server = return "1\n" :<|> return "2\n"
    
    app :: Application
    app = serve api server
    
    counter :: IORef Int -> Middleware
    counter cref oldapp req respond
      = do n <- atomicModifyIORef cref (\n' -> (n'+1,n'+1))
           putStrLn $ "Request #" ++ show n ++ ": " ++ showRequest req
           oldapp req respond
    
      where
        showRequest req = C.unpack (requestMethod req) ++ " " ++ C.unpack (rawPathInfo req)
    
    main :: IO ()
    main = do
      cref <- newIORef (0 :: Int)
      run 3000 $ counter cref app