
Haskell: Memoising using MonadState inside a Free Monad Interpreter


Given that I have the following DSL (using a Free Monad) and its interpreter:

data MyDslF next =
    GetThingById Int (Thing -> next)
  | Log Text next
  deriving Functor  -- iterT/iterTM require a Functor instance (DeriveFunctor)

type MyDslT = FT MyDslF

runMyDsl :: (MonadLogger m, MonadIO m, MonadCatch m) => MyDslT m a -> m a
runMyDsl = iterT run
  where
    run :: (MonadLogger m, MonadIO m, MonadCatch m) => MyDslF (m a) -> m a
    run (Log message continue)      = Logger.log message >> continue
    run (GetThingById id' continue) = SomeApi.getThingById id' >>= continue
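To make this starting point concrete, here is a self-contained sketch that can actually be run. It is built on assumptions that are not in the question: `Thing` is a hypothetical newtype over `Int`, `Log` carries a `String` instead of `Text`, logging is plain `putStrLn` with only a `MonadIO` constraint, `SomeApi.getThingById` is replaced by a hypothetical `fakeApi` that counts its calls in an `IORef`, and the smart constructors `getThingById`/`logMsg` plus `program`/`demoPlain` are illustrative names. It needs the `free` package.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Control.Monad.Free.Class (liftF)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Free.Church (FT, iterT)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)

-- Hypothetical stand-in for the question's Thing type.
newtype Thing = Thing Int deriving (Eq, Show)

-- Same shape as MyDslF, but Log carries a String instead of Text.
data MyDslF next
  = GetThingById Int (Thing -> next)
  | Log String next
  deriving Functor

type MyDslT = FT MyDslF

-- Hypothetical smart constructors for the two instructions.
getThingById :: Monad m => Int -> MyDslT m Thing
getThingById i = liftF (GetThingById i id)

logMsg :: Monad m => String -> MyDslT m ()
logMsg s = liftF (Log s ())

-- A program that asks for the same id twice.
program :: Monad m => MyDslT m (Thing, Thing)
program = do
  a <- getThingById 1
  logMsg "fetched 1"
  b <- getThingById 1
  pure (a, b)

-- Hypothetical stand-in for SomeApi.getThingById: counts every call.
fakeApi :: MonadIO m => IORef Int -> Int -> m Thing
fakeApi calls i = liftIO (modifyIORef' calls (+ 1)) >> pure (Thing i)

-- Non-memoised interpreter, structured like the question's runMyDsl.
runPlain :: MonadIO m => IORef Int -> MyDslT m a -> m a
runPlain calls = iterT run
  where
    run (Log msg k)        = liftIO (putStrLn msg) >> k
    run (GetThingById i k) = fakeApi calls i >>= k

-- Runs `program` and reports the result and the number of API calls.
demoPlain :: IO ((Thing, Thing), Int)
demoPlain = do
  calls <- newIORef 0
  r <- runPlain calls program
  n <- readIORef calls
  pure (r, n)
```

Running `demoPlain` prints "fetched 1" and returns `((Thing 1, Thing 1), 2)`: the repeated lookup of id 1 reaches the API twice, which is exactly the duplicate call the memoisation should remove.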

I would like to change the interpreter internally to use MonadState, so that if a Thing has already been retrieved for a given Id, SomeApi is not called a second time.

Let's assume I already know how to write the memoised version using get and put; the problem I am having is running the MonadState inside runMyDsl. I was thinking the solution would look similar to:

type ThingMap = Map Int Thing

runMyDsl :: (MonadLogger m, MonadIO m, MonadCatch m) => MyDslT m a -> m a
runMyDsl = flip evalStateT mempty . iterT run
  where
    run :: (MonadLogger m, MonadIO m, MonadCatch m, MonadState ThingMap m) => MyDslF (m a) -> m a
    run ..

But the types do not align: iterT run produces a value in the same monad m that the DSL is built over, so run would need a MonadState ThingMap m instance for the caller's m, while evalStateT expects a StateT ThingMap m a.
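For reference, the get/put memoisation that the question takes as given can be written independently of the DSL. This is a minimal sketch that uses the `StateT`, `get` and `put` from the transformers package (mtl's `MonadState` versions behave the same on `StateT`); `memoLookup`, `countingFetch` and `demoMemo` are hypothetical names, and the counter in an outer `State Int` only exists to make the number of real fetches observable.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (State, StateT, evalStateT, get, modify, put, runState)
import qualified Data.Map.Strict as Map

-- Look k up in the cache held by StateT; run the real action only on a miss.
memoLookup :: (Monad m, Ord k) => (k -> m v) -> k -> StateT (Map.Map k v) m v
memoLookup fetch k = do
  cache <- get
  case Map.lookup k cache of
    Just v  -> pure v                 -- hit: fetch is not called
    Nothing -> do
      v <- lift (fetch k)             -- miss: run the underlying action once
      put (Map.insert k v cache)
      pure v

-- Hypothetical fetch that counts its own invocations in the outer State Int.
countingFetch :: Int -> State Int String
countingFetch k = modify (+ 1) >> pure ("thing-" ++ show k)

-- Four lookups over two distinct keys: the cache starts empty, the counter at 0.
demoMemo :: ([String], Int)
demoMemo = runState (evalStateT (mapM (memoLookup countingFetch) [1, 2, 1, 1]) Map.empty) 0
```

`demoMemo` evaluates to `(["thing-1","thing-2","thing-1","thing-1"], 2)`. Note that the cache lives in a `StateT` layered on top of the base monad, which is the shape the interpreter has to run in.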


Solution

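  • Use iterTM instead of iterT:

    import Control.Monad.State.Strict (StateT, evalStateT, get, put)
    import Control.Monad.Trans.Class (lift)
    import Control.Monad.Trans.Free.Church (iterTM)
    import qualified Data.Map as Map

    runMyDsl :: (MonadLogger m, MonadIO m, MonadCatch m) => MyDslT m a -> m a
    runMyDsl dsl = evalStateT (iterTM run dsl) Map.empty
      where
        -- run targets StateT ThingMap n, which is the monad iterTM returns in
        run :: (MonadLogger n, MonadIO n, MonadCatch n)
            => MyDslF (StateT ThingMap n b) -> StateT ThingMap n b
        run (Log message continue)      = lift (Logger.log message) >> continue
        run (GetThingById id' continue) = do
          cache <- get
          case Map.lookup id' cache of
            Nothing -> do
              -- cache miss: call the API once and remember the result
              thing <- lift (SomeApi.getThingById id')
              put (Map.insert id' thing cache)
              continue thing
            Just thing -> continue thing  -- cache hit: no API call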
    

    Equivalently, you can use iterT if you first lift the MyDslT m a into a MyDslT (StateT ThingMap m) a using hoistFT lift, like so:

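    -- hoistFT and iterT come from Control.Monad.Trans.Free.Church;
    -- run is the same memoising where-clause as in the iterTM version
    runMyDsl :: (MonadLogger m, MonadIO m, MonadCatch m) => MyDslT m a -> m a
    runMyDsl dsl = evalStateT (iterT run (hoistFT lift dsl)) Map.empty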
    

    This turns dsl into a MyDslT (StateT ThingMap m) a that does not itself perform any state updates; all of the state transitions happen in run.
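Both variants can be checked side by side with a runnable sketch. It rests on assumptions that are not in the original: `Thing` is a hypothetical newtype over `Int`, `Log` carries a `String`, logging is `putStrLn`, only `MonadIO` is required, and `SomeApi.getThingById` is replaced by a hypothetical `fakeApi` that counts calls in an `IORef`. The names `memoStep`, `runMemoTM`, `runMemoHoist`, `program` and `demo` are illustrative. It needs the `free` package.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Control.Monad.Free.Class (liftF)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Free.Church (FT, hoistFT, iterT, iterTM)
import Control.Monad.Trans.State.Strict (StateT, evalStateT, get, put)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import qualified Data.Map.Strict as Map

newtype Thing = Thing Int deriving (Eq, Show)  -- hypothetical payload type
type ThingMap = Map.Map Int Thing

data MyDslF next
  = GetThingById Int (Thing -> next)
  | Log String next
  deriving Functor

type MyDslT = FT MyDslF

getThingById :: Monad m => Int -> MyDslT m Thing
getThingById i = liftF (GetThingById i id)

logMsg :: Monad m => String -> MyDslT m ()
logMsg s = liftF (Log s ())

-- Asks for id 1 twice; with memoisation only the first request hits the API.
program :: Monad m => MyDslT m (Thing, Thing)
program = do
  a <- getThingById 1
  logMsg "fetched 1"
  b <- getThingById 1
  pure (a, b)

-- Hypothetical stand-in for SomeApi.getThingById: counts every call.
fakeApi :: MonadIO m => IORef Int -> Int -> m Thing
fakeApi calls i = liftIO (modifyIORef' calls (+ 1)) >> pure (Thing i)

-- One interpretation step in StateT ThingMap over the base monad n.
memoStep :: MonadIO n => IORef Int -> MyDslF (StateT ThingMap n b) -> StateT ThingMap n b
memoStep _ (Log msg k) = liftIO (putStrLn msg) >> k
memoStep calls (GetThingById i k) = do
  cache <- get
  case Map.lookup i cache of
    Just thing -> k thing                -- cache hit: no API call
    Nothing -> do
      thing <- lift (fakeApi calls i)    -- cache miss: call the API once
      put (Map.insert i thing cache)
      k thing

-- Variant 1: iterTM lets the step function target StateT ThingMap m directly.
runMemoTM :: MonadIO m => IORef Int -> MyDslT m a -> m a
runMemoTM calls dsl = evalStateT (iterTM (memoStep calls) dsl) Map.empty

-- Variant 2: hoistFT lift moves the DSL into StateT first, then plain iterT.
runMemoHoist :: MonadIO m => IORef Int -> MyDslT m a -> m a
runMemoHoist calls dsl = evalStateT (iterT (memoStep calls) (hoistFT lift dsl)) Map.empty

-- Runs `program` with an interpreter; returns the result and the API-call count.
demo :: (IORef Int -> MyDslT IO (Thing, Thing) -> IO (Thing, Thing)) -> IO ((Thing, Thing), Int)
demo runner = do
  calls <- newIORef 0
  r <- runner calls program
  n <- readIORef calls
  pure (r, n)
```

`demo runMemoTM` and `demo runMemoHoist` should each return `((Thing 1, Thing 1), 1)`. Giving `memoStep` an explicit signature makes the fix visible: its result type is `StateT ThingMap n b`, which matches what `iterTM` (or `iterT` after `hoistFT lift`) expects, rather than the bare base monad that `iterT` alone would require.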