Lifting a complete monadic action to a transformer (>>= but for Monad Transformers)


I looked hard to see if this may be a duplicate question but couldn't find anything that addressed specifically this. My apologies if there actually is something.

So, I get how lift works: it lifts a (fully defined) monadic action from the underlying monad into the transformed monad. Cool.

But what if I want to apply a (>>=) from one level under the transformer into the transformer? I'll explain with an example.

Say MyTrans is a MonadTrans, and there is also an instance Monad m => Monad (MyTrans m). Now, the (>>=) from this instance will have this signature:

instance Monad m => Monad (MyTrans m) where
   (>>=) :: MyTrans m a -> (a -> MyTrans m b) -> MyTrans m b

but what I need is something like this:

(>>=!) :: Monad m => MyTrans m a -> (m a -> MyTrans m b) -> MyTrans m b

In general:

(>>=!) :: (MonadTrans t, Monad m) => t m a -> (m a -> t m b) -> t m b

It looks like a combination of the original (>>=) and lift, except it really isn't. lift can only be used on covariant arguments of type m a to transform them into a t m a, not the other way around. In other words, the following has the wrong type:

(>>=!?) :: Monad m => MyTrans m a -> (a -> m b) -> MyTrans m b
x >>=!? f = x >>= (lift . f)

Of course a general colift :: (MonadTrans t, Monad m) => t m a -> m a makes absolutely zero sense, because surely the transformer is doing something that we cannot just throw away like that in all cases.

But just like (>>=) introduces contravariant arguments into the monad by ensuring that they will always "come back", I thought something along the lines of the (>>=!) function would make sense: Yes, it in some way makes an m a from a t m a, but only because it does all of this within t, just like (>>=) makes an a from an m a in some way.

I've thought about it and I don't think (>>=!) can be in general defined from the available tools. In some sense it is more than what MonadTrans gives. I haven't found any related type classes that offer this either. MFunctor is related but it is a different thing, for changing the inner monad, but not for chaining exclusively transformer-related actions.

By the way, here is an example of why you would want to do this:

EDIT: I tried to present a simple example but I realized that that one could be solved with the regular (>>=) from the transformer. My real example (I think) cannot be solved with this. If you think every case can be solved with the usual (>>=), please do explain how.

Should I just define my own type class for this and give some basic implementations? (I'm interested in StateT, and I'm almost certain it can be implemented for it) Am I doing something in a twisted way? Is there something I overlooked?

Thanks.

EDIT: The answer provided by Fyodor matches the types, but does not do what I want, since by using pure, it is ignoring the monadic effects of the m monad. Here is an example of it giving the wrong answer:

Take t = StateT Int and m = [].

x1 :: StateT Int [] Int
x1 = StateT (\s -> [(1,s),(2,s),(3,s)])

x2 :: StateT Int [] Int
x2 = StateT (\s -> [(1,s),(2,s),(3,s),(4,s)])

f :: [Int] -> StateT Int [] Int
f l = StateT $ \s ->
  if even s then []
  else if even (length l) then fmap (\z -> (z, z + s)) l
  else [(123, 123)]

runStateT (x1 >>= (\a -> f (pure a))) 1 returns [(123,123),(123,123),(123,123)] as expected, since the state 1 is odd and the list in x1 has odd length.

But runStateT (x2 >>= (\a -> f (pure a))) 1 returns [(123,123),(123,123),(123,123),(123,123)], whereas I would have expected it to return [(1,2),(2,3),(3,4),(4,5)], since the state 1 is odd and the length of the list is even. Instead, because of the pure call, f is evaluated independently on the singleton lists [1], [2], [3] and [4], one for each of the results (1,1), (2,1), (3,1) and (4,1).
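
For anyone who wants to reproduce this, here is a minimal self-contained sketch of the example above. It assumes the transformers package for StateT, and viaPure is just an illustrative name for the bind + pure approach, not anything from a library.

```haskell
import Control.Monad.Trans.State (StateT (..))

x1, x2 :: StateT Int [] Int
x1 = StateT (\s -> [(1, s), (2, s), (3, s)])
x2 = StateT (\s -> [(1, s), (2, s), (3, s), (4, s)])

f :: [Int] -> StateT Int [] Int
f l = StateT $ \s ->
  if even s then []
  else if even (length l) then fmap (\z -> (z, z + s)) l
  else [(123, 123)]

-- Bind + pure (illustrative name): f only ever receives a singleton list,
-- so `length l` is always 1 no matter how many results x produces.
viaPure :: StateT Int [] Int -> StateT Int [] Int
viaPure x = x >>= \a -> f (pure a)
```

In GHCi, runStateT (viaPure x1) 1 and runStateT (viaPure x2) 1 can be compared directly; both consist only of (123,123) entries, one per element of the original list.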


Solution

  • This can be very trivially implemented via bind + pure. Consider the signature:

    (>>=!) :: (Monad m, MonadTrans t) => t m a -> (m a -> t m b) -> t m b
    

    If you use bind on the first argument, you get yourself a naked a, and since m is a Monad, you can trivially turn that naked a into an m a via pure. Therefore, the straightforward implementation would be:

    (>>=!) x f = x >>= \a -> f (pure a)
    

    And because of this, bind is always strictly more powerful than your proposed new operation (>>=!), which is probably the reason it doesn't exist in the standard libraries.


    I think it may be possible to propose more clever interpretations of (>>=!) for some specific transformers or specific underlying monads. For example, if m ~ [], one might imagine passing the whole list as m a instead of its elements one by one, as my generic implementation above would do. But this sort of thing seems too specific to be implemented in general.
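
To make the "pass the whole m a" idea concrete, here is a rough sketch specialised to StateT. The name bindWhole is made up for illustration and is not a library function. It is also lossy by construction: it hands the continuation the whole underlying action, but discards any state changes made by the first argument, so treat it as an illustration of one possible interpretation rather than a recommended definition.

```haskell
import Control.Monad.Trans.State (StateT (..))

-- Hypothetical StateT-only variant of (>>=!): the continuation receives the
-- whole underlying action. State updates performed by x are thrown away
-- (that is what `fst <$>` does); k is then run from the original state s.
bindWhole :: Functor m => StateT s m a -> (m a -> StateT s m b) -> StateT s m b
bindWhole x k = StateT $ \s -> runStateT (k (fst <$> runStateT x s)) s

-- x2 and f as in the question.
x2 :: StateT Int [] Int
x2 = StateT (\s -> [(1, s), (2, s), (3, s), (4, s)])

f :: [Int] -> StateT Int [] Int
f l = StateT $ \s ->
  if even s then []
  else if even (length l) then fmap (\z -> (z, z + s)) l
  else [(123, 123)]
```

With these definitions, runStateT (bindWhole x2 f) 1 evaluates f on the full list [1,2,3,4] and yields [(1,2),(2,3),(3,4),(4,5)], but only because x2 never modifies the state in the first place.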

    If you have a very specific example of what you're after, and you can show that my above general implementation doesn't work, then perhaps I can provide a better answer.


    Ok, to address your actual problem from the comments:

    I have a function f :: m a -> m b -> m c that I want to transform into a function ff :: StateT s m a -> StateT s m b -> StateT s m c

    I think looking at this example may illustrate the difficulty better. Consider the required signature:

    liftish :: Monad m => (m a -> m b -> m c) -> StateT s m a -> StateT s m b -> StateT s m c
    

    Presumably, you'd want to keep the effects of m that are already "imprinted" within the StateT s m a and StateT s m b parameters (because if you don't, my simple solution above will work). To do this, you can "unwrap" the StateT via runStateT, which (after dropping the state component) gets you m a and m b respectively, which you can then use to obtain m c:

    liftish f sa sb = do
      s <- get
      let ma = fst <$> runStateT sa s
          mb = fst <$> runStateT sb s
      lift $ f ma mb
    

    But here's the trouble: see those fst <$> in there? They are throwing away the resulting state. The call to runStateT sa s results not only in the m a value, but also in the new, modified state. And same goes for runStateT sb s. And presumably you'd want to get the state that resulted from runStateT sa and pass it to runStateT sb, right? Otherwise you're effectively dropping some state mutations.
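
A small sketch of that loss, assuming the transformers package and using Maybe as the underlying monad purely for illustration (sa, sb and dropped are made-up names):

```haskell
import Control.Applicative (liftA2)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT, get, modify, runStateT)

-- First attempt from above: both arguments run from the same initial state,
-- and the states they produce are discarded by `fst <$>`.
liftish :: Monad m => (m a -> m b -> m c) -> StateT s m a -> StateT s m b -> StateT s m c
liftish f sa sb = do
  s <- get
  let ma = fst <$> runStateT sa s
      mb = fst <$> runStateT sb s
  lift $ f ma mb

-- Two illustrative actions that each modify the state.
sa, sb :: StateT Int Maybe Int
sa = modify (+ 1) >> pure 10
sb = modify (* 2) >> pure 20

-- The final state is still the initial 1: both modifications were dropped.
dropped :: Maybe (Int, Int)
dropped = runStateT (liftish (liftA2 (+)) sa sb) 1
```

Here dropped is Just (30,1), whereas sequencing the same two actions inside StateT, as in runStateT (liftA2 (+) sa sb) 1, gives Just (30,4).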

    But you can't get to the resulting state of runStateT sa, because it's "wrapped" inside m: runStateT returns m (a, s) rather than (m a, s). If you knew how to "unwrap" m, you'd be fine, but you don't. So the only way to get that intermediate state is to run the effects of m:

    liftish f sa sb = do
      s <- get
      (c, s'') <- lift $ do
        let ma = runStateT sa s
        (_, s') <- ma
        let mb = runStateT sb s'
        (_, s'') <- mb
        c <- f (fst <$> ma) (fst <$> mb)
        pure (c, s'')
      put s''
      pure c
    

    But now see what happens: I'm using ma and mb twice: once to get the new states out of them, and second time by passing them to f. This may lead to double-running effects or worse.
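
The double execution can be made visible with a logging monad underneath. This is only a sketch: it assumes Writer from the transformers package as the underlying monad, and sa, sb and result are illustrative names.

```haskell
import Control.Applicative (liftA2)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT, get, modify, put, runStateT)
import Control.Monad.Trans.Writer (Writer, runWriter, tell)

-- Second attempt from above: the state is threaded through, but ma and mb
-- are each executed twice (once for their state, once inside f).
liftish :: Monad m => (m a -> m b -> m c) -> StateT s m a -> StateT s m b -> StateT s m c
liftish f sa sb = do
  s <- get
  (c, s'') <- lift $ do
    let ma = runStateT sa s
    (_, s') <- ma
    let mb = runStateT sb s'
    (_, s'') <- mb
    c <- f (fst <$> ma) (fst <$> mb)
    pure (c, s'')
  put s''
  pure c

-- Each action logs a line in the underlying Writer when it runs.
sa, sb :: StateT Int (Writer [String]) Int
sa = lift (tell ["run sa"]) >> modify (+ 1) >> pure 10
sb = lift (tell ["run sb"]) >> modify (* 2) >> pure 20

result :: ((Int, Int), [String])
result = runWriter (runStateT (liftish (liftA2 (+)) sa sb) 1)
```

The state now comes out right (result has value 30 and final state 4), but the log is ["run sa","run sb","run sa","run sb"]: every underlying effect has happened twice.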

    This problem of "double execution" will, I think, show up for any monad transformer, simply because the transformer's effects are always wrapped inside the underlying monad, so you have a choice: either drop the transformer's effects or execute the underlying monad's effects twice.