Tags: haskell, monads, monad-transformers

How to refactor code handling IO (Maybe a) to use a monad transformer?


I'm trying to understand monad transformers. I have code like this (which doesn't compile):

import System.IO (hFlush, stdout)
import Text.Read (readMaybe)

add1 :: Int -> IO (Maybe Int)
add1 x = return $ Just (x + 1)

readNumber :: IO (Maybe Int)
readNumber = do
  putStr "Say a number: "
  hFlush stdout
  inp <- getLine
  return $ (readMaybe inp :: Maybe Int)

main :: IO ()
main = do
  x <- readNumber >>= add1
  print x

It fails to compile with:

Couldn't match type ‘Int’ with ‘Maybe Int’
      Expected: Maybe Int -> IO (Maybe Int)
        Actual: Int -> IO (Maybe Int)

I figured out that I can make it work by introducing

(>>>=) :: IO (Maybe a) -> (a -> IO (Maybe b)) -> IO (Maybe b)
x >>>= f =
  x >>= go f
  where
    go _ Nothing = return Nothing
    go f (Just x) = f x

and using it instead of >>=. This is strikingly similar to a monad transformer, but I can't get my head around how exactly I should refactor this code to use it.
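The >>>= above is already the bind of a "Maybe on top of IO" monad; what is missing is a type for >>= to dispatch on, since IO (Maybe a) on its own uses IO's bind. A minimal sketch, assuming a hypothetical newtype called MaybeIO (not from any library), wraps IO (Maybe a) and moves the body of >>>= into a Monad instance:

```haskell
import System.IO (hFlush, stdout)
import Text.Read (readMaybe)

-- Hypothetical name: a hand-rolled Maybe-over-IO monad.
-- It only wraps the question's IO (Maybe a) in a newtype.
newtype MaybeIO a = MaybeIO { runMaybeIO :: IO (Maybe a) }

instance Functor MaybeIO where
  fmap f (MaybeIO io) = MaybeIO (fmap (fmap f) io)

instance Applicative MaybeIO where
  pure x = MaybeIO (pure (Just x))
  MaybeIO iof <*> MaybeIO iox = MaybeIO $ do
    mf <- iof
    case mf of
      Nothing -> pure Nothing          -- no function: skip the second action
      Just f  -> fmap (fmap f) iox

instance Monad MaybeIO where
  -- This is the question's >>>=, with the newtype unwrapped and rewrapped.
  MaybeIO io >>= f = MaybeIO $ do
    m <- io
    case m of
      Nothing -> pure Nothing          -- short-circuit on Nothing
      Just x  -> runMaybeIO (f x)

add1 :: Int -> MaybeIO Int
add1 x = pure (x + 1)

readNumber :: MaybeIO Int
readNumber = MaybeIO $ do
  putStr "Say a number: "
  hFlush stdout
  inp <- getLine
  pure (readMaybe inp)

main :: IO ()
main = do
  -- Plain >>= now works because both sides live in MaybeIO.
  x <- runMaybeIO (readNumber >>= add1)
  print x
```

Replacing the fixed IO with a type parameter m (so the newtype wraps m (Maybe a)) gives the MaybeT transformer from the transformers package, whose >>= is defined in the same way.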

You may wonder "why does add1 return IO?" Assume that it could be something more complicated that uses IO.

I'm looking to understand this better, so answers like "there is a better solution" or "it is already implemented in..." won't help. I would like to learn what I need to change to make it work with >>=, given that I want to do operations like IO (Maybe a) -> (a -> IO (Maybe b)) -> IO (Maybe b), which already work with my >>>=.
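If the existing IO (Maybe a) functions should stay exactly as written, one option is to wrap them at the call site. A sketch, assuming the transformers package (it ships with GHC): the MaybeT constructor turns an IO (Maybe a) into a MaybeT IO a, runMaybeT turns it back, and the ordinary >>= of MaybeT in between does what >>>= did.

```haskell
import System.IO (hFlush, stdout)
import Text.Read (readMaybe)
import Control.Monad.Trans.Maybe (MaybeT (..))

-- The question's functions, unchanged.
add1 :: Int -> IO (Maybe Int)
add1 x = return $ Just (x + 1)

readNumber :: IO (Maybe Int)
readNumber = do
  putStr "Say a number: "
  hFlush stdout
  inp <- getLine
  return (readMaybe inp)

-- The question's operator, re-expressed through MaybeT's own >>=:
-- wrap with the MaybeT constructor, bind, then unwrap with runMaybeT.
(>>>=) :: IO (Maybe a) -> (a -> IO (Maybe b)) -> IO (Maybe b)
x >>>= f = runMaybeT (MaybeT x >>= MaybeT . f)

main :: IO ()
main = do
  x <- runMaybeT (MaybeT readNumber >>= MaybeT . add1)
  print x
```

Wrapping at each call site is only a stepping stone; giving add1 and readNumber the type MaybeT IO Int (or a class-constrained type, as in the answer below) removes the wrapping altogether.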


Solution

  • I'd say the most common way to use monad transformers is the mtl approach. It consists of writing your program against type classes like MonadIO and MonadFail, and then, in main, using a concrete transformer like MaybeT to instantiate those classes and get the actual result.

    For your program, that could look like this:

    import System.IO (hFlush, stdout)
    import Text.Read (readMaybe)
    import Control.Monad.Trans.Maybe (runMaybeT)
    import Control.Monad.IO.Class (MonadIO (liftIO))
    import Control.Monad (MonadFail (fail))
    
    add1 :: Monad m => Int -> m Int
    add1 x = pure (x + 1)
    
    prompt :: String -> IO String
    prompt x = do
      putStr x
      hFlush stdout
      getLine
    
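    -- Works in any monad that can run IO actions and can fail.
    readNumber :: (MonadIO m, MonadFail m) => m Int
    readNumber = do
      inp <- liftIO (prompt "Say a number: ")
      case readMaybe inp of
        Nothing -> fail "Not a number"  -- in MaybeT, fail short-circuits to Nothing
        Just x -> pure x
    
    main :: IO ()
    main = do
      -- runMaybeT fixes m = MaybeT IO and unwraps the result to IO (Maybe Int)
      x <- runMaybeT (readNumber >>= add1)
      print x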
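In the code above, add1 only needs Monad m, so it no longer does any IO, whereas the question says add1 may need IO. A sketch of how that fits the same approach: give add1 the same constraints as readNumber and lift its IO with liftIO. The logging line and the "negative input" rule are hypothetical, added only to show one IO action and one failure.

```haskell
import Control.Monad.IO.Class (MonadIO (liftIO))
import Control.Monad.Trans.Maybe (runMaybeT)

-- Hypothetical add1 that performs IO (logging) and can fail.
-- MonadFail is in the Prelude on GHC 8.8 and later.
add1 :: (MonadIO m, MonadFail m) => Int -> m Int
add1 x = do
  liftIO (putStrLn ("adding 1 to " ++ show x))
  if x < 0
    then fail "negative input"   -- in MaybeT IO this becomes Nothing
    else pure (x + 1)

main :: IO ()
main = do
  r1 <- runMaybeT (add1 41)
  print r1
  r2 <- runMaybeT (add1 (-1))
  print r2
```

Because this add1 has the same constraints as readNumber, runMaybeT (readNumber >>= add1) still type-checks with m = MaybeT IO, and a failure in either step yields Nothing.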