Tags: haskell, types, monads, monad-transformers, state-monad

Why does a working monadic linkage between two functions break when I introduce a new argument that is parameterized by the monad?


The following code:

-- tst2.hs - showing successful monadic linkage.

{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE FlexibleContexts #-}

import Protolude
import Control.Monad.Extra  (unfoldM)

foo :: Monad m
    => Int
    -> m [[Int]]
foo n = evalStateT (traverse nxt [1..n]) 0
 where nxt _ = do s <- get
                  r <- bar s
                  put $ s + 1
                  return r

bar :: Monad m
    => Int
    -> m [Int]
bar n = unfoldM step n
  where step k = return $ if k > 0 then Just (k, k - 1)
                                   else Nothing

main :: IO ()
main = do xs <- foo 3
          print xs

works fine, producing this output:

Davids-Air-2:haskell-rl dbanas$ stack runghc tst2.hs 
[[],[1],[2,1]]

However, if I change the code slightly:

-- tst3.hs - showing breakage of monadic linkage.

{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE FlexibleContexts #-}

import Protolude
import Control.Monad.Extra  (unfoldM)

data Dummy m = Dummy (Int -> m Int)

foo :: Monad m
    => Int
    -> Dummy m
    -> m [[Int]]
foo n d = evalStateT (traverse nxt [1..n]) 0
 where nxt _ = do s <- get
                  r <- bar s d
                  put $ s + 1
                  return r

bar :: Monad m
    => Int
    -> Dummy m
    -> m [Int]
bar n d = unfoldM step n
  where step k = return $ if k > 0 then Just (k, k - 1)
                                   else Nothing

main :: IO ()
main = do xs <- foo 3 $ Dummy (\n -> return n)
          print xs

I get the following compiler error:

Davids-Air-2:haskell-rl dbanas$ stack runghc tst3.hs 

tst3.hs:15:23: error:
    • Couldn't match type ‘m’ with ‘StateT Integer m’
      ‘m’ is a rigid type variable bound by
        the type signature for:
          foo :: forall (m :: * -> *). Monad m => Int -> Dummy m -> m [[Int]]
        at tst3.hs:(11,1)-(14,16)
      Expected type: StateT Integer m [[Int]]
        Actual type: m [[Int]]
    • In the first argument of ‘evalStateT’, namely
        ‘(traverse nxt [1 .. n])’
      In the expression: evalStateT (traverse nxt [1 .. n]) 0
      In an equation for ‘foo’:
          foo n d
            = evalStateT (traverse nxt [1 .. n]) 0
            where
                nxt _
                  = do s <- get
                       ....
    • Relevant bindings include
        nxt :: forall p. p -> m [Int] (bound at tst3.hs:16:8)
        d :: Dummy m (bound at tst3.hs:15:7)
        foo :: Int -> Dummy m -> m [[Int]] (bound at tst3.hs:15:1)
   |
15 | foo n d = evalStateT (traverse nxt [1..n]) 0
   |                       ^^^^^^^^^^^^^^^^^^^

The fix turns out to be very simple, although it took me several hours to work out: lift the call to bar inside nxt:

          r <- lift $ bar s d

Davids-Air-2:haskell-rl dbanas$ stack runghc tst3.hs 
[[],[1],[2,1]]
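For reference, here is a sketch of the whole corrected tst3.hs. It is a rewrite under stated assumptions, not the original file: it swaps Protolude for plain base plus the transformers modules Control.Monad.Trans.State.Strict and Control.Monad.Trans.Class, replaces unfoldM with the pure Data.List.unfoldr (the step function never used the monad anyway), and turns on ScopedTypeVariables only so that nxt can carry a type signature that makes the StateT Int m layer visible.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- Assumption: base + transformers stand in for Protolude and extra.
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (StateT, evalStateT, get, put)
import Data.List (unfoldr)

data Dummy m = Dummy (Int -> m Int)

-- "forall m." brings m into scope for the local signature on nxt.
foo :: forall m. Monad m => Int -> Dummy m -> m [[Int]]
foo n d = evalStateT (traverse nxt [1 .. n]) 0
  where
    -- nxt's do block runs one layer above m, in StateT Int m.
    nxt :: Int -> StateT Int m [Int]
    nxt _ = do
      s <- get
      -- bar s d :: m [Int], because d :: Dummy m pins bar's monad to m;
      -- lift moves that action up into StateT Int m.
      r <- lift (bar s d)
      put (s + 1)
      return r

-- The Dummy argument is unused here, exactly as in the question.
bar :: Monad m => Int -> Dummy m -> m [Int]
bar n _d = return (unfoldr step n) -- pure stand-in for unfoldM
  where
    step k = if k > 0 then Just (k, k - 1) else Nothing

main :: IO ()
main = do
  xs <- foo 3 (Dummy return)
  print xs -- prints [[],[1],[2,1]]
```

The signature on nxt is optional; it only makes explicit that bar s d has type m [Int] while nxt's do block needs StateT Int m [Int], and lift bridges exactly that gap.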

I have a theory about this and would like confirmation:

The monad operating in bar is different from the monad operating in foo. Specifically, the monad operating in foo is the lifted version of the monad operating in bar (lifted into a StateT context, to be precise).

In tst2.hs, where there is no type linkage between the two m's (foo's and bar's), this works perfectly. In tst3.hs, however, the Dummy m argument links the two, forcing them to be the same monad, and that is why the compiler complains.

My solution works only because the monad operating in foo really is the lifted version of the one operating in bar. If the two monads were completely unrelated, my solution would not work.

Is that all correct?
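To make the tst2.hs half of this theory concrete, here is a sketch of tst2.hs with the types spelled out. It again assumes base plus transformers instead of Protolude/extra, and uses ScopedTypeVariables so the local annotations can mention foo's m. It suggests a small refinement of the theory: in tst2.hs, bar is not running in a lower monad that then gets lifted; its unconstrained m is simply instantiated to StateT Int m, the same monad as nxt's do block.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

-- Assumption: base + transformers stand in for Protolude and extra.
import Control.Monad.Trans.State.Strict (StateT, evalStateT, get, put)
import Data.List (unfoldr)

foo :: forall m. Monad m => Int -> m [[Int]]
foo n = evalStateT (traverse nxt [1 .. n]) 0
  where
    nxt :: Int -> StateT Int m [Int]
    nxt _ = do
      s <- get
      -- bar is polymorphic in its monad and nothing pins it to m,
      -- so here it is used directly at StateT Int m; no lift needed.
      r <- (bar s :: StateT Int m [Int])
      put (s + 1)
      return r

bar :: Monad m => Int -> m [Int]
bar n = return (unfoldr step n) -- pure stand-in for unfoldM
  where
    step k = if k > 0 then Just (k, k - 1) else Nothing

main :: IO ()
main = foo 3 >>= print -- prints [[],[1],[2,1]]
```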


Solution

  • When you call foo from main, you are saying that m should be IO, so Dummy (\n -> return n) is a Dummy IO (this could just be Dummy return, by the way). Inside foo, you then pass that Dummy IO to bar, which sets bar's m to IO as well. (Strictly speaking, while checking foo's own definition GHC treats m as a rigid type variable rather than IO, as the error message says, but the argument is the same.) However, the call to bar sits inside nxt's do block, which runs in StateT Int IO rather than IO; hence the error.

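    As you discovered, you can call bar with the Dummy IO and then lift the result into StateT Int IO. This works because StateT Int IO is built on top of IO, and lift moves an action from the inner monad into the transformer. You are also correct that there are situations where this would not work, for example when the two monads are unrelated.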

    There's another solution you may want to consider. If you don't actually care which monad your Dummy type uses (as seems to be the case), you can require that its function work in every monad:

    newtype Dummy = Dummy (Int -> (forall m. Monad m => m Int))
    

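    This requires the RankNTypes extension. Because Dummy no longer has an m parameter, nothing ties bar's monad to foo's, so bar can again be used directly inside the StateT do block without lift.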
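Here is a sketch of tst3.hs rewritten with the rank-N Dummy, assuming base plus transformers instead of Protolude/extra. One deliberate deviation: the forall is written at the front of the field type (forall m. Monad m => Int -> m Int) rather than after the Int ->. Both express the same requirement, but on GHC 9.x with simplified subsumption the nested form can reject Dummy return and require the eta-expanded Dummy (\n -> return n), whereas the front-loaded form accepts Dummy return directly. The `_ <- f n` line in bar is hypothetical, there only to show that the field can be called at whatever monad bar happens to run in.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Assumption: base + transformers stand in for Protolude and extra.
import Control.Monad.Trans.State.Strict (evalStateT, get, put)
import Data.List (unfoldr)

-- The wrapped function must work in every Monad, so Dummy itself
-- no longer mentions a monad in its type.
newtype Dummy = Dummy (forall m. Monad m => Int -> m Int)

foo :: Monad m => Int -> Dummy -> m [[Int]]
foo n d = evalStateT (traverse nxt [1 .. n]) 0
  where
    nxt _ = do
      s <- get
      -- d :: Dummy carries no m, so bar is free to run at StateT Int m.
      r <- bar s d
      put (s + 1)
      return r

bar :: Monad m => Int -> Dummy -> m [Int]
bar n (Dummy f) = do
  _ <- f n -- hypothetical use: f is instantiated at bar's own monad
  return (unfoldr step n) -- pure stand-in for unfoldM
  where
    step k = if k > 0 then Just (k, k - 1) else Nothing

main :: IO ()
main = foo 3 (Dummy return) >>= print -- prints [[],[1],[2,1]]
```

Note that foo needs no lift here: bar s d is used at StateT Int m directly, just as bar s is in tst2.hs.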