
State and error monad stack with state rollback on error


I have a problem that I will illustrate through the following example:

Let's say I want to do some computations that can yield a result or an error, while carrying a state. For that, I have the following monad stack:

import Control.Monad.Trans.Class ( lift )
import Control.Monad.Trans.State ( get, modify, State )
import Control.Monad.Trans.Except ( catchE, throwE, ExceptT )

type MyMonad a = ExceptT String (State [Int]) a

So the state is a list of Ints, errors are Strings, and computations can yield a value of any type "a". I can do things like:

putNumber :: Int -> MyMonad ()
putNumber i = lift $ modify (i:)

Now, suppose I define a function that pushes half of the most recently added number (the head of the list) onto the state:

putHalf :: MyMonad ()
putHalf = do
  s <- lift get
  case s of
    (x:_) -> if even x then putNumber (div x 2) else throwE "Number can't be halved"
    [] -> throwE "The state is empty"

Running putHalf either adds a number to the state and returns (), or throws one of the two errors.
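As an illustration, here is a self-contained version of the definitions above with a main that runs putHalf from three starting states, one per outcome. The helper runMyMonad is a hypothetical name introduced only for this sketch; lift comes from Control.Monad.Trans.Class in transformers.

```haskell
import Control.Monad.Trans.Class ( lift )
import Control.Monad.Trans.State ( State, get, modify, runState )
import Control.Monad.Trans.Except ( ExceptT, runExceptT, throwE )

type MyMonad a = ExceptT String (State [Int]) a

putNumber :: Int -> MyMonad ()
putNumber i = lift $ modify (i:)

putHalf :: MyMonad ()
putHalf = do
  s <- lift get
  case s of
    (x:_) -> if even x then putNumber (div x 2) else throwE "Number can't be halved"
    []    -> throwE "The state is empty"

-- Hypothetical helper (not part of any library): run an action from an
-- initial state and pair the success/failure value with the final state.
runMyMonad :: MyMonad a -> [Int] -> (Either String a, [Int])
runMyMonad act = runState (runExceptT act)

main :: IO ()
main = do
  print (runMyMonad putHalf [4])  -- (Right (),[2,4])
  print (runMyMonad putHalf [3])  -- (Left "Number can't be halved",[3])
  print (runMyMonad putHalf [])   -- (Left "The state is empty",[])
```

Note that a failing putHalf still returns a final state; here it is simply unchanged because putHalf throws before modifying anything.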

If an error occurs, I would like to be able to call an alternative function. I know I can achieve this with catchE by doing something like this:

putWithAlternative :: MyMonad ()
putWithAlternative = putHalf `catchE` (\_ -> putNumber 12)

In this case, if putHalf fails for any reason, the number 12 is added to the state instead. Up to this point everything is fine. However, suppose I define a function that calls putHalf twice:

putHalfTwice :: MyMonad ()
putHalfTwice = putHalf >> putHalf

The problem is that if, for example, the state contains only the number 2, the first call to putHalf succeeds and pushes 1, but the second one fails because 1 is odd. I need putHalfTwice to either perform both calls and modify the state twice, or perform neither and leave the state as it was. I can't use catchE or putWithAlternative, because the modification made by the first call is still in the state when the error is caught.
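A minimal reproduction of the problem, using the definitions above (plus the lift import from Control.Monad.Trans.Class) and running from the state [2]:

```haskell
import Control.Monad.Trans.Class ( lift )
import Control.Monad.Trans.State ( State, get, modify, runState )
import Control.Monad.Trans.Except ( ExceptT, catchE, runExceptT, throwE )

type MyMonad a = ExceptT String (State [Int]) a

putNumber :: Int -> MyMonad ()
putNumber i = lift $ modify (i:)

putHalf :: MyMonad ()
putHalf = do
  s <- lift get
  case s of
    (x:_) -> if even x then putNumber (div x 2) else throwE "Number can't be halved"
    []    -> throwE "The state is empty"

putHalfTwice :: MyMonad ()
putHalfTwice = putHalf >> putHalf

main :: IO ()
main = do
  -- The first putHalf pushes 1, the second fails on the odd 1,
  -- and the pushed 1 survives in the final state.
  print (runState (runExceptT putHalfTwice) [2])
  -- prints: (Left "Number can't be halved",[1,2])

  -- catchE recovers from the error, but the handler runs on top of [1,2].
  print (runState (runExceptT (putHalfTwice `catchE` (\_ -> putNumber 12))) [2])
  -- prints: (Right (),[12,1,2])
```

The desired result for the second run would be [12,2], i.e. the 1 pushed by the first putHalf should have been undone.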

I know the Parsec library does this through its <|> operator and try combinator. How could I go about defining these myself? Is there an existing monad transformer that already achieves this?


Solution

  • If, in your problem domain, failure should never modify the state, the most straightforward thing to do is to invert the layers:

    type MyMonad' a = StateT [Int] (Except String) a
    

    Writing e for the error type and s for the state type, your original monad is isomorphic to:

    s -> (Either e a, s)
    

    so it always returns a new state, whether it succeeds or fails. This new monad is isomorphic to:

    s -> Either e (a, s)
    

    so it either fails, discarding any state changes made along the way, or returns a result together with a new state.

    The following program recovers from putHalfTwice without mangling the state:

    import Control.Monad.Trans.Class ( lift )
    import Control.Monad.Trans.State
    import Control.Monad.Trans.Except
    
    type MyMonad' a = StateT [Int] (Except String) a
    
    putNumber :: Int -> MyMonad' ()
    putNumber i = modify (i:)
    
    putHalf :: MyMonad' ()
    putHalf = do
      s <- get
      case s of
        (x:_) -> if even x then putNumber (div x 2) else lift $ throwE "Number can't be halved"
        [] -> lift $ throwE "The state is empty"
    
    putHalfTwice :: MyMonad' ()
    putHalfTwice = putHalf >> putHalf
    
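    -- liftCatch catchE runs the handler from the state that putHalfTwice
    -- started with, so the half-finished modification is discarded.
    foo :: MyMonad' ()
    foo = liftCatch catchE putHalfTwice (\_ -> putNumber 12)
    
    main :: IO ()
    main = do
      print $ runExcept (runStateT foo [2])  -- Right ((),[12,2])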
    

    Otherwise, if you want backtracking to be opt-in, you can write your own try that saves the state, catches the error, restores the saved state, and rethrows:

    try :: MyMonad a -> MyMonad a
    try act = do
      s <- lift get                                  -- save the state before running act
      act `catchE` (\e -> lift (put s) >> throwE e)  -- on failure, restore it and rethrow
    

    and then:

    import Control.Monad.Trans.Class ( lift )
    import Control.Monad.Trans.State
    import Control.Monad.Trans.Except
    
    type MyMonad a = ExceptT String (State [Int]) a
    
    putNumber :: Int -> MyMonad ()
    putNumber i = lift $ modify (i:)
    
    putHalf :: MyMonad ()
    putHalf = do
      s <- lift get
      case s of
        (x:_) -> if even x then putNumber (div x 2) else throwE "Number can't be halved"
        [] -> throwE "The state is empty"
    
    putHalfTwice :: MyMonad ()
    putHalfTwice = putHalf >> putHalf
    
    try :: MyMonad a -> MyMonad a
    try act = do
      s <- lift get
      act `catchE` (\e -> lift (put s) >> throwE e)
    
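    -- Without try, the handler sees the state left by the failed action.
    foo :: MyMonad ()
    foo = putHalfTwice `catchE` (\_ -> putNumber 12)
    
    -- With try, the state is rolled back before the handler runs.
    bar :: MyMonad ()
    bar = try putHalfTwice `catchE` (\_ -> putNumber 12)
    
    main :: IO ()
    main = do
      print $ runState (runExceptT foo) [2]  -- (Right (),[12,1,2])
      print $ runState (runExceptT bar) [2]  -- (Right (),[12,2])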
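On the Parsec-style <|>: as far as I know, transformers already provides an Alternative instance for ExceptT e m whenever e is a Monoid, which String is. With it, a <|> b runs b if a throws (and concatenates the two messages if both throw). Unlike Parsec's <|>, it has no notion of consumed input, so it falls back after any error, and the state is only rolled back if you wrap the first branch in try. A sketch reusing the definitions above, with (<|>) imported from Control.Applicative:

```haskell
import Control.Applicative ( (<|>) )
import Control.Monad.Trans.Class ( lift )
import Control.Monad.Trans.State ( State, get, put, modify, runState )
import Control.Monad.Trans.Except ( ExceptT, catchE, runExceptT, throwE )

type MyMonad a = ExceptT String (State [Int]) a

putNumber :: Int -> MyMonad ()
putNumber i = lift $ modify (i:)

putHalf :: MyMonad ()
putHalf = do
  s <- lift get
  case s of
    (x:_) -> if even x then putNumber (div x 2) else throwE "Number can't be halved"
    []    -> throwE "The state is empty"

putHalfTwice :: MyMonad ()
putHalfTwice = putHalf >> putHalf

-- Save the state, and on failure restore it before rethrowing.
try :: MyMonad a -> MyMonad a
try act = do
  s <- lift get
  act `catchE` (\e -> lift (put s) >> throwE e)

main :: IO ()
main = do
  -- Without try, (<|>) falls back on top of the partially modified state.
  print $ runState (runExceptT (putHalfTwice <|> putNumber 12)) [2]      -- (Right (),[12,1,2])
  -- With try, the state is restored to [2] before falling back.
  print $ runState (runExceptT (try putHalfTwice <|> putNumber 12)) [2]  -- (Right (),[12,2])
```

For the inverted stack StateT [Int] (Except String), I believe StateT's own Alternative instance (which relies on the MonadPlus instance of ExceptT e m, again requiring a Monoid error type) restarts the second branch from the original state, so there putHalfTwice <|> putNumber 12 should already behave like the liftCatch version without needing try.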