
Run a cleanup function in multiple Haskell child threads when a POSIX signal (SIGTERM, etc.) is sent


TL;DR - how do I make the following work in Haskell:

  • Send a SIGTERM to a long-running program with many active threads (each working on a job)
  • Have all child threads run a cleanup function (updating the database to say the job has been aborted) before they exit

To my (very inexperienced) mind, the cleanest way to make this happen seems to be to trap the SIGTERM in the 'main' thread, raise asynchronous exceptions in the child threads, and then use bracket in the child threads to react to the asynchronous exception by running some cleanup code. Empirically, however, I cannot make this work.
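
A minimal sketch of the last two steps of that plan, assuming the `async` package. `Shutdown` is a hypothetical exception type (not part of any library), and the `putStrLn` stands in for the database update. Defining `toException`/`fromException` via `asyncExceptionToException`/`asyncExceptionFromException` marks `Shutdown` as an asynchronous exception, and `cancelWith` throws it to the child and waits for the child to exit, so the child's `onException` cleanup has run by the time `cancelWith` returns. A dedicated type like this would also let a worker tell a shutdown apart from a user cancellation.

```haskell
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (async, cancelWith, waitCatch)
import Control.Exception
  ( Exception (..)
  , asyncExceptionFromException
  , asyncExceptionToException
  , onException
  )
import Data.Either (isLeft)

-- Hypothetical exception telling workers that the process is shutting down.
data Shutdown = Shutdown deriving (Show)

-- Register Shutdown as an *asynchronous* exception, so code that checks
-- for SomeAsyncException treats it like a cancellation.
instance Exception Shutdown where
  toException = asyncExceptionToException
  fromException = asyncExceptionFromException

-- A long-running worker; the putStrLn stands in for "mark job as aborted".
-- onException runs the cleanup and then rethrows the exception.
worker :: IO ()
worker =
  threadDelay 10000000
    `onException` putStrLn "worker: marking job as aborted"

main :: IO ()
main = do
  a <- async worker
  threadDelay 100000     -- give the worker time to start
  cancelWith a Shutdown  -- throws Shutdown, then waits for the worker to exit
  r <- waitCatch a
  putStrLn ("worker ended with an exception: " ++ show (isLeft r))
```

Running it should print the worker's cleanup line first, then `worker ended with an exception: True`.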


More colour:

I have a Haskell program that spawns a number of threads to do work (using async). Basically, it:

  • Waits on notifications from the database job queue for new jobs
  • Spawns a new thread to do the work in
    • The thread updates the job status in the database as it progresses (e.g. running, paused)
    • If the job completes, the user cancels the job, or a synchronous exception occurs, it updates the database with the final state (e.g. completed, cancelled, aborted)
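
The lifecycle above fits in one wrapper around each job. This is a minimal sketch, not the asker's code: `Status`, `runJob` and the `setStatus` callback are hypothetical stand-ins for the real database update. Normal completion records `Completed`, an asynchronous exception (cancellation, shutdown) records `Aborted`, and a synchronous exception records `Failed`; the exception is then rethrown so whoever waits on the worker still sees it.

```haskell
import Control.Exception
  ( Exception (..)
  , SomeAsyncException
  , SomeException
  , mask
  , throwIO
  , try
  )
import Data.IORef (modifyIORef, newIORef, readIORef)
import Data.Maybe (isJust)

-- Hypothetical job states; a real program would store these in the database.
data Status = Running | Completed | Aborted | Failed String
  deriving (Show, Eq)

-- True for asynchronous exceptions (cancellation, shutdown, timeouts, ...).
isAsync :: SomeException -> Bool
isAsync e = isJust (fromException e :: Maybe SomeAsyncException)

-- Run one job, recording its status before and after the work.
-- `mask` stops an async exception from arriving between the end of the
-- work and the final status update; `restore` keeps the work interruptible.
runJob :: (Status -> IO ()) -> IO a -> IO a
runJob setStatus work = mask $ \restore -> do
  setStatus Running
  r <- try (restore work)
  case r of
    Right x -> setStatus Completed >> pure x
    Left e
      | isAsync e -> setStatus Aborted >> throwIO e            -- cancelled / shutting down
      | otherwise -> setStatus (Failed (show e)) >> throwIO e  -- the job itself failed

main :: IO ()
main = do
  statuses <- newIORef []
  let record s = modifyIORef statuses (++ [s])
  runJob record (pure ())
  _ <- try (runJob record (throwIO (userError "boom"))) :: IO (Either SomeException ())
  readIORef statuses >>= print  -- the recorded status updates, in order
```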

Crucially, the main thread runs forever, just listening for new jobs, unless interrupted by SIGINT, SIGTERM, SIGKILL etc.

When the program gets a SIGINT or SIGTERM, I want to run some cleanup (namely, updating the database to set the status of in-flight jobs to aborted) before the 'child' thread dies. However, I absolutely cannot figure out how to do this.

My understanding is that the way to handle exceptions thrown from a 'parent' thread to a 'child' thread is to use bracket, which runs its acquire and release actions with asynchronous exceptions masked (the main body of work stays interruptible), so the cleanup function is guaranteed to run before the thread terminates.
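
A quick experiment shows what `bracket` guarantees. This sketch uses `bracket_` (the variant that ignores the acquired value) inside a thread started with `async`; the log messages are hypothetical placeholders for the database updates. `cancel` interrupts the body, and the release action still runs before `cancel` returns.

```haskell
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (async, cancel)
import Control.Exception (bracket_)
import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef)

-- Append a message to a shared log (safe to call from several threads).
logMsg :: IORef [String] -> String -> IO ()
logMsg ref s = atomicModifyIORef' ref (\xs -> (xs ++ [s], ()))

main :: IO ()
main = do
  ref <- newIORef []
  worker <- async $
    bracket_
      (logMsg ref "acquire: mark job running")     -- runs masked
      (logMsg ref "release: record final status")  -- runs masked, even on cancel
      (threadDelay 10000000)                       -- body: interruptible
  threadDelay 100000  -- let the worker enter the body
  cancel worker       -- throws AsyncCancelled and waits for the worker to exit
  readIORef ref >>= mapM_ putStrLn
```

Both messages are printed, acquire first, even though the body never finished.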

However, bracket doesn't seem to interact well with signal handling. I have tried installing signal handlers to convert the SIGTERM into a runtime exception that I can handle properly. This works well for the thread in which I installed the handler, but I can't throw an asynchronous exception to the other threads; I think this is because they have also received the SIGTERM, so they just die immediately.

It also appears that I can't install an individual SIGTERM handler per thread, because the runtime seems to allow only one handler per signal across all threads (if I try, the last thread to install its handler gets the signal, while all other threads, including the main thread, keep running).
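
This matches how `installHandler` from the `unix` package behaves: there is one process-wide handler per signal (installing again replaces the previous one), and a `Catch` action runs in a new Haskell thread, not in any of the existing worker threads. A minimal sketch that checks both points, using `SIGUSR1` and `raiseSignal` so it runs without sending a signal by hand; `report` is a hypothetical helper.

```haskell
import Control.Concurrent (ThreadId, myThreadId)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar)
import System.Posix.Signals (Handler (..), installHandler, raiseSignal, sigUSR1)

-- Record which handler ran, and on which thread.
report :: MVar (String, ThreadId) -> String -> IO ()
report box name = do
  tid <- myThreadId
  putMVar box (name, tid)

main :: IO ()
main = do
  mainId <- myThreadId
  box <- newEmptyMVar
  -- Two installs for the same signal: the second replaces the first.
  _ <- installHandler sigUSR1 (Catch (report box "first")) Nothing
  _ <- installHandler sigUSR1 (Catch (report box "second")) Nothing
  raiseSignal sigUSR1
  (name, handlerId) <- takeMVar box
  putStrLn ("handler that ran: " ++ name)                          -- second
  putStrLn ("ran on the main thread: " ++ show (handlerId == mainId))  -- False
```

So the handler never runs "in" a worker; to reach the workers, it has to forward an exception to whichever thread knows about them.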


Edited to add

Here's some example code that I've developed based on @Li-yao Xia's answer (which was super helpful - thank you).

One piece of the puzzle is that I'm creating child threads inside a recursive function (which listens for notifications on a job queue, then potentially spawns new workers in response). However, I can't see how to pass the list of child threads to the exception handler unless I attach the handler to every call of the recursive function (see the example code below). This means that I can't effectively use tail-call recursion, and if I terminate the program, the exception handler gets called once for every loop iteration as the stack unwinds. Is there a better pattern to make this work?

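-- NumDecimals lets the literals 5e5 and 1e6 below be used as Int delays.
{-# LANGUAGE NumDecimals #-}

import Control.Concurrent.Async (Async, async, cancelWith)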

import Control.Exception (AsyncException (..), Exception, SomeException, catch, throwIO)
import System.Posix (Signal)
import System.Posix.Signals (Handler (..), installHandler, sigHUP, sigINT, sigTERM, sigUSR1, sigUSR2, sigXCPU, sigXFSZ)

import Control.Concurrent (myThreadId, threadDelay, throwTo)
import Control.Monad (forM_)
import Data.Data (Typeable)

import Data.Foldable (for_)

data Result = Done | Aborted deriving (Show)

termMsg :: Int -> Result -> IO ()
termMsg n s = putStrLn $ "Thread " ++ show n ++ " terminated with " ++ show s

thread :: Int -> IO ()
thread n = job `catch` asyncHandler
  where
    job = do
      for_ ([0 .. 9] :: [Int]) $ \_ -> do
        putStrLn $ "Thread " ++ show n ++ " alive..."
        threadDelay $ 5e5 * n
      termMsg n Done

    asyncHandler :: AsyncException -> IO ()
    asyncHandler _ = do
      termMsg n Aborted

parent :: IO ()
parent = run 0 []
  where
    run :: Int -> [Async ()] -> IO ()
    run n w = do
      ( do
          putStrLn $ "Main thread alive (loop " ++ show n ++ ")"
          t <- async (thread n)
          let nw = w ++ [t]
          threadDelay 1e6
          run (n + 1) nw
        )
        `catch` handler w

    handler :: [Async ()] -> SomeException -> IO ()
    handler children e = do
      print e
      cleanupChildren children
      throwIO e

main :: IO ()
main = do
  installSignalHandlers
  parent `catch` someExceptionHandler

cleanupChildren :: [Async ()] -> IO ()
cleanupChildren children = do
  putStrLn "Cleaning up children..."
  for_ children $ \t -> cancelWith t ThreadKilled

someExceptionHandler :: SomeException -> IO ()
someExceptionHandler e = do
  putStrLn $ "Terminating with " ++ show e
  throwIO e

data SignalException = SignalException Signal String
  deriving (Show, Typeable, Eq)
instance Exception SignalException

signalsToHandle :: [(Signal, String)]
signalsToHandle = [(sigHUP, "SIGHUP"), (sigINT, "SIGINT"), (sigTERM, "SIGTERM"), (sigUSR1, "SIGUSR1"), (sigUSR2, "SIGUSR2"), (sigXCPU, "SIGXCPU"), (sigXFSZ, "SIGXFSZ")]

installSignalHandlers :: IO ()
installSignalHandlers = do
  mainId <- myThreadId
  forM_ signalsToHandle $ \(sig, name) -> installHandler sig (Catch (throwTo mainId $ SignalException sig name)) Nothing

Result:

Main thread alive (loop 0)
Thread 0 alive...

...

Main thread alive (loop 7)
Thread 7 alive...
Thread 3 alive...
Thread 5 alive...
Thread 4 alive...
Thread 2 alive...
Main thread alive (loop 8)
Thread 8 alive...
^CSignalException 2 "SIGINT"
Cleaning up children...
Thread 2 terminated with Aborted
Thread 3 terminated with Aborted
Thread 4 terminated with Aborted
Thread 5 terminated with Aborted
Thread 6 terminated with Aborted
Thread 7 terminated with Aborted
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
SignalException 2 "SIGINT"
Cleaning up children...
Terminating with SignalException 2 "SIGINT"
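
On the question of a better pattern: one alternative to attaching a handler on every iteration is to keep the children in a shared registry and install a single handler around the whole loop. This is a minimal sketch under assumptions: `parent`, the integer jobs and `worker` are hypothetical, and `timeout` stands in for the signal-triggered exception. Because `onException` wraps the loop once, `loop` stays tail-recursive, and `cancel` waits for each child's cleanup.

```haskell
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (Async, asyncWithUnmask, cancel)
import Control.Concurrent.Chan (newChan, readChan, writeChan)
import Control.Exception (mask_, onException)
import Data.IORef (IORef, atomicModifyIORef', newIORef, readIORef)
import System.Timeout (timeout)

-- Listen for jobs forever, spawning one worker per job. Children are kept
-- in a registry so that one handler, installed once, can cancel them all.
parent :: IO j -> (j -> IO ()) -> IO ()
parent listen worker = do
  registry <- newIORef [] :: IO (IORef [Async ()])
  let loop = do
        job <- listen  -- blocking, hence interruptible even under mask_
        -- Start the child unmasked, and register it before any async
        -- exception can be delivered to this (masked) thread.
        a <- asyncWithUnmask (\unmask -> unmask (worker job))
        atomicModifyIORef' registry (\as -> (a : as, ()))
        loop           -- tail call: no handler frame per iteration
      -- cancel waits for each child, so their cleanup has finished
      -- before the exception continues to propagate.
      cleanup = readIORef registry >>= mapM_ cancel
  mask_ loop `onException` cleanup

main :: IO ()
main = do
  chan <- newChan
  mapM_ (writeChan chan) [1 .. 3 :: Int]
  let worker n =
        (threadDelay 10000000 >> putStrLn ("job " ++ show n ++ " done"))
          `onException` putStrLn ("job " ++ show n ++ " aborted")
  -- timeout plays the role of the signal: after 0.5 s it throws an
  -- async exception to this thread, which is blocked in readChan.
  r <- timeout 500000 (parent (readChan chan) worker)
  print r
```

With three queued jobs this should print `job 3 aborted`, `job 2 aborted`, `job 1 aborted` and then `Nothing`. One limitation of the sketch: finished children stay in the registry, so a long-running service would want to remove them (for example, from each worker's own `finally`).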

Solution

  • Here is a small example.

    You may want to post your own minimal example to help diagnose your particular issue.

    Could it be that your main thread is terminating without waiting for its children?

    import Control.Concurrent (threadDelay, myThreadId)
    import Control.Concurrent.Async
    import Control.Exception
    import Data.Foldable (for_)
    import Data.Traversable (for)
    import System.Posix.Signals (installHandler, sigTERM, Handler(..))
    
    data Result = Done | Aborted deriving Show
    
    thread :: IO Result
    thread = job `catch` handler
      where
        job = do
          threadDelay 5000000
          pure Done
        handler AsyncCancelled = do
          -- additional clean up can be done here
          pure Aborted
    
    main :: IO ()
    main = do
      -- install handler for SIGTERM: throw UserInterrupt to main thread
      -- (SIGINT is already installed by default)
      mainId <- myThreadId
      installHandler sigTERM (Catch (throwTo mainId UserInterrupt)) Nothing
    
      -- spawn threads
      children <- for [0..9] $ \_ ->
        async thread
    
      -- wait for threads to terminate
      let waitAll = do
            for_ children $ \ t -> do
              wait t
              pure ()
            putStrLn "Normal termination"
      waitAll `catch` \e -> case e of
        UserInterrupt -> do
          putStrLn "Killed."
          putStrLn "Cleaning up..."
          for_ children $ \ t ->
            cancel t
          putStrLn "Waiting on children"
          results <- for children $ \ t ->
            wait t
          putStrLn ("Job results: " ++ show results)
        e -> throwIO e
    
    

    Output after SIGINT or SIGTERM:

    ^CKilled.
    Cleaning up...
    Waiting on children
    Job results: [Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted,Aborted]
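
A related point about waiting for children: `withAsync` ties a child's lifetime to a scope. When the body exits, normally or by an exception, the child is cancelled and, in recent versions of `async`, waited for, so its cleanup finishes before the exception propagates further. A minimal sketch, with `UserInterrupt` thrown by hand as a stand-in for Ctrl-C; `child` is a hypothetical worker.

```haskell
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (withAsync)
import Control.Exception (AsyncException (..), finally, throwIO, try)

-- A long-running child whose cleanup must not be skipped.
child :: IO ()
child = threadDelay 10000000 `finally` putStrLn "child: cleanup ran"

main :: IO ()
main = do
  r <- try $ withAsync child $ \_ -> do
    threadDelay 100000     -- give the child time to start
    throwIO UserInterrupt  -- stand-in for the exception a Ctrl-C would deliver
  -- By now withAsync has cancelled the child and waited for it.
  putStrLn ("main: " ++ either show (const "finished normally") (r :: Either AsyncException ()))
```

This should print `child: cleanup ran` followed by `main: user interrupt`.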
    

    Update

    Your modified example creates threads dynamically (you don't know ahead of time how many threads you will need) in a loop that is meant to listen for jobs. That complicates the structure of the program a little. Below is a fixed version.

    1. In the parent thread, `try listen` waits for either a job or an async exception. We pattern-match on its result outside the exception handler (`try`), so we either keep looping without growing the stack (`run` is tail-recursive) or get an exception and clean up the child threads.
    2. `mask_` ensures that only the blocking `listen` call may receive an async exception. (Here it's not necessary to unmask around `listen`, because it performs a blocking operation, and blocking operations are interruptible: async exceptions can still be delivered while they wait, even under `mask_`.) The children are started with `asyncWithUnmask` so that they do not inherit the mask.
    3. Make sure your cleanup actually waits for the children.
    import Control.Concurrent (newChan, writeChan, readChan)
    import Control.Concurrent.Async (Async, async, withAsync, asyncWithUnmask, wait, cancelWith)
    import Control.Exception (AsyncException (..), Exception, SomeException, catch, try, throwIO, mask_)
    import System.Posix (Signal)
    import System.Posix.Signals (Handler (..), installHandler, sigHUP, sigINT, sigTERM, sigUSR1, sigUSR2, sigXCPU, sigXFSZ)
    import Control.Concurrent (myThreadId, threadDelay, throwTo)
    import Control.Monad (forM_)
    import Data.Data (Typeable)
    import Data.Foldable (for_)
    
    data Result = Done | Aborted deriving (Show)
    
    termMsg :: Int -> Result -> IO ()
    termMsg n s = putStrLn $ "Thread " ++ show n ++ " terminated with " ++ show s
    
    data Job = Job Int
      deriving Show
    
    thread :: Job -> IO ()
    thread (Job n) = job `catch` asyncHandler
      where
        job = do
          for_ ([0 .. 9] :: [Int]) $ \_ -> do
            putStrLn $ "Thread " ++ show n ++ " alive..."
            threadDelay $ 500000 * n  -- 0.5 s per step, i.e. 5e5 in the question
          termMsg n Done
    
        asyncHandler :: AsyncException -> IO ()
        asyncHandler _ = do
          -- cleanup
          termMsg n Aborted
    
    parent :: IO Job -> IO ()
    parent listen = mask_ $ run []
      where
        run :: [Async ()] -> IO ()
        run w = do
          event <- try listen
          putStrLn $ "Main thread alive (loop " ++ show event ++ ")"
          case event :: Either SignalException Job of
            Right job -> do
              t <- asyncWithUnmask $ \unmask -> unmask (thread job)
              let nw = w ++ [t]
              run nw
            Left e -> do
              cleanupChildren w
              throwIO e
    
    main :: IO ()
    main = do
      installSignalHandlers
      -- `parent` waits for jobs by calling its `IO job` argument.
      -- We implement it here by reading from a channel which gets populated by the createJobs thread below.
      chan <- newChan
      let listen = readChan chan
          createJobs = for_ [1 .. 9] $ \i -> do
            threadDelay 1000000
            writeChan chan (Job i)
      withAsync createJobs $ \_ ->
        parent listen
    
    cleanupChildren :: [Async ()] -> IO ()
    cleanupChildren children = do
      putStrLn "Cleaning up children..."
      for_ children $ \t -> cancelWith t ThreadKilled
      -- Wait for the children to terminate their own cleanup
      for_ children $ \t -> wait t >> pure ()
    
    data SignalException = SignalException Signal String
      deriving (Show, Typeable, Eq)
    instance Exception SignalException
    
    signalsToHandle :: [(Signal, String)]
    signalsToHandle = [(sigHUP, "SIGHUP"), (sigINT, "SIGINT"), (sigTERM, "SIGTERM"), (sigUSR1, "SIGUSR1"), (sigUSR2, "SIGUSR2"), (sigXCPU, "SIGXCPU"), (sigXFSZ, "SIGXFSZ")]
    
    installSignalHandlers :: IO ()
    installSignalHandlers = do
      mainId <- myThreadId
      forM_ signalsToHandle $ \(sig, name) -> installHandler sig (Catch (throwTo mainId $ SignalException sig name)) Nothing
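
Point 2 of the update relies on blocking operations being interruptible: even inside `mask_`, an asynchronous exception is delivered while the thread is blocked in, for example, `readChan` or `takeMVar`. A minimal sketch to check this, with `takeMVar` on an empty `MVar` standing in for `listen`:

```haskell
import Control.Concurrent (forkIO, killThread, threadDelay)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, takeMVar)
import Control.Exception (AsyncException (..), mask_, try)

main :: IO ()
main = do
  never <- newEmptyMVar :: IO (MVar ())  -- nothing is ever put here
  result <- newEmptyMVar
  t <- forkIO $ do
    -- Async exceptions are masked, but takeMVar blocks, so it is interruptible.
    r <- try (mask_ (takeMVar never))
    putMVar result (r :: Either AsyncException ())
  threadDelay 100000
  killThread t               -- throws ThreadKilled to the blocked thread
  takeMVar result >>= print  -- prints: Left thread killed
```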