
How to preserve the state of the monad stack in the IO exception handler?


Consider the following program.

import Control.Monad.State
import Control.Monad.Catch

ex1 :: StateT Int IO ()
ex1 = do
    modify (+10)
    liftIO . ioError $ userError "something went wrong"

ex2 :: StateT Int IO ()
ex2 = do
    x <- get
    liftIO $ print x

ex3 :: StateT Int IO ()
ex3 = ex1 `onException` ex2

main :: IO ()
main = evalStateT ex3 0

When we run the program we get the following output.

$ runhaskell Test.hs
0
Test.hs: user error (something went wrong)

However, I expected the output to be as follows.

$ runhaskell Test.hs
10
Test.hs: user error (something went wrong)

How do I preserve the intermediate state of ex1 so that the exception handler ex2 can see it?


Solution
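Why ex2 prints 0: a `StateT s IO a` is just a function `s -> IO (a, s)`, so the state produced by `modify (+10)` only leaves ex1 as part of its result, and ex1 never produces a result because it throws. The `MonadCatch` instance for `StateT` in the exceptions package runs the handler starting from the state that was current when the protected action began. Below is a simplified, runnable model of that behavior. `onExceptionS` is a hypothetical name for illustration, not the library's actual code, and the `catch` in `main` is only there so the program exits cleanly.

```haskell
import Control.Exception (IOException, catch, onException)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State (StateT (StateT), evalStateT, get, modify, runStateT)

-- Hypothetical, simplified model of onException for StateT: the action and
-- the handler both start from the same entry state s. The action's updated
-- state would only come back inside IO (a, s), which never happens because
-- the action throws.
onExceptionS :: StateT s IO a -> StateT s IO b -> StateT s IO a
onExceptionS action handler = StateT $ \s ->
  runStateT action s `onException` runStateT handler s

ex1 :: StateT Int IO ()
ex1 = do
  modify (+10)
  liftIO . ioError $ userError "something went wrong"

ex2 :: StateT Int IO ()
ex2 = do
  x <- get
  liftIO $ print x

main :: IO ()
main =
  evalStateT (ex1 `onExceptionS` ex2) 0
    `catch` \e -> putStrLn ("rethrown: " ++ show (e :: IOException))
-- Prints 0, then "rethrown: user error (something went wrong)".
```

The handler never sees 10 because nothing ever hands it the state that ex1 built up; a `catch`-style combinator on `StateT` can only re-supply the entry state.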

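  • Use an IORef (or MVar or TVar or whatever) instead, so the state lives in a mutable cell outside StateT's return value, where an exception cannot discard it.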

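    {-# LANGUAGE FlexibleInstances, GeneralizedNewtypeDeriving, MultiParamTypeClasses #-}
    
    import Control.Monad.Catch (MonadCatch, MonadThrow)
    import Control.Monad.IO.Class (MonadIO, liftIO)
    import Control.Monad.Reader (ReaderT, ask, runReaderT)
    import Control.Monad.State (MonadState (..))
    import Control.Monad.Trans.Class (MonadTrans)
    import Data.IORef (IORef, readIORef, writeIORef)
    
    -- The state lives in an IORef that a ReaderT hands around, so a write
    -- made before an exception is still there when a handler reads it.
    newtype IOStateT s m a = IOStateT { unIOStateT :: ReaderT (IORef s) m a }
        deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadThrow, MonadCatch)
        -- MonadThrow/MonadCatch reuse ReaderT's instances from exceptions;
        -- onException needs MonadCatch.
        -- N.B. not MonadReader! You want that instance to pass through to m,
        -- unlike ReaderT's instance (which would expose the IORef), so you
        -- have to write the instance by hand.
    
    runIOStateT :: IOStateT s m a -> IORef s -> m a
    runIOStateT = runReaderT . unIOStateT -- or runIOStateT = coerce if you're feeling cheeky
    
    -- Every state operation reads and writes the IORef directly.
    instance MonadIO m => MonadState s (IOStateT s m) where
        state f = IOStateT $ do
            ref <- ask
            liftIO $ do
                s <- readIORef ref
                let (a, s') = f s
                writeIORef ref s'
                pure a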
    

    This feels like a pattern I've seen enough times that there ought to be a Hackage package for it, but I don't know of one.
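Putting it together: a sketch of the question's program on top of IOStateT, assuming the mtl, transformers and exceptions packages. `onException` from `Control.Monad.Catch` needs a `MonadCatch` instance, which the deriving clause gets from ReaderT's instances in exceptions; `MonadTrans` is dropped here only because nothing below uses it. The `catch` in `main` is an addition so the final state can be read back from the IORef after the error escapes.

```haskell
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}

import Control.Exception (IOException)
import Control.Monad.Catch (MonadCatch, MonadThrow, catch, onException)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT, ask, runReaderT)
import Control.Monad.State (MonadState (..), modify)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

-- Trimmed copy of IOStateT: MonadThrow/MonadCatch come from ReaderT's
-- instances in the exceptions package; MonadTrans is omitted (unused here).
newtype IOStateT s m a = IOStateT { unIOStateT :: ReaderT (IORef s) m a }
  deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadCatch)

runIOStateT :: IOStateT s m a -> IORef s -> m a
runIOStateT = runReaderT . unIOStateT

-- get, put and modify all go through this, so they hit the IORef directly.
instance MonadIO m => MonadState s (IOStateT s m) where
  state f = IOStateT $ do
    ref <- ask
    liftIO $ do
      s <- readIORef ref
      let (a, s') = f s
      writeIORef ref s'
      pure a

-- The question's program, unchanged except for the monad.
ex1 :: IOStateT Int IO ()
ex1 = do
  modify (+10)
  liftIO . ioError $ userError "something went wrong"

ex2 :: IOStateT Int IO ()
ex2 = do
  x <- get
  liftIO $ print x

ex3 :: IOStateT Int IO ()
ex3 = ex1 `onException` ex2

main :: IO ()
main = do
  ref <- newIORef 0
  -- Catch the rethrown error so we can still inspect the IORef afterwards.
  runIOStateT ex3 ref `catch` \e -> putStrLn ("rethrown: " ++ show (e :: IOException))
  readIORef ref >>= print
```

Running it should print 10 from ex2, then `rethrown: user error (something went wrong)`, then 10 again, read straight from the IORef. The handler sees 10 because each state write lands in the IORef immediately instead of riding along in a return value that the exception throws away.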