
Understanding do notation for a simple Reader monad: a <- (*2), b <- (+10), return (a+b)


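-- The Monad instance for functions, shown for reference only: base
-- already defines it, so redefining it would be a duplicate instance.
instance Monad ((->) r) where
    return x = \_ -> x
    h >>= f = \w -> f (h w) w

-- Control.Monad.Instances is deprecated; in modern GHC the instance
-- above is in scope without any import.
-- import Control.Monad.Instances

addStuff :: Int -> Int
addStuff = do
    a <- (*2)
    b <- (+10)
    return (a+b)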
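For anyone who wants to run the example: a minimal sketch, assuming a reasonably recent GHC where the function instance comes from base and no import is needed. The `main` below is added for illustration and is not part of the question.

```haskell
-- addStuff from the question; no import is needed because base
-- already provides the Monad instance for functions ((->) r).
addStuff :: Int -> Int
addStuff = do
  a <- (*2)       -- a is bound to (input * 2)
  b <- (+10)      -- b is bound to (input + 10)
  return (a + b)  -- ignore the input and yield a + b

main :: IO ()
main = print (addStuff 3)  -- prints 19
```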

I'm trying to understand this monad by unwinding the do notation, because I think the do notation hides what happens.

If I understood correctly, this is what happens:

(*2) >>= (\a -> (+10) >>= (\b -> return (a+b))) 
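One way to check that this desugaring is right is to define both versions and compare them on a range of inputs. A small sketch; `addStuffDesugared` is a hypothetical name introduced here:

```haskell
-- The question's do block.
addStuff :: Int -> Int
addStuff = do
  a <- (*2)
  b <- (+10)
  return (a + b)

-- Hand-desugared form: each "x <- m" becomes "m >>= \x -> ...".
addStuffDesugared :: Int -> Int
addStuffDesugared = (*2) >>= (\a -> (+10) >>= (\b -> return (a + b)))

main :: IO ()
main =
  -- Both definitions should agree on every input tried here.
  print (all (\w -> addStuff w == addStuffDesugared w) [-100 .. 100])
```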

Now, if we take the rule for >>=, we must understand (*2) as h and (\a -> (+10) >>= (\b -> return (a+b))) as f. Applying h to w is easy; let's just say it is 2w (I don't know if 2w is valid in Haskell, but for the sake of reasoning let's keep it this way). Now we have to apply f to h w, that is, to 2w. Well, f simply returns (+10) >>= (\b -> return (a+b)) for a specific a, which is 2w in our case, so f (h w) is (+10) >>= (\b -> return (2w+b)). We must first work out what (+10) >>= (\b -> return (2w + b)) is before finally applying it to w.

Now we re-identify (+10) >>= (\b -> return (2w + b)) with our rule, so h is (+10) and f is (\b -> return (2w + b)). Let's first do h w. We get w + 10. Now we need to apply f to h w. We get return (2w + w + 10).

So return (2w + w + 10) is what we need to apply to w in the first >>= that we were trying to unwind. But I'm totally lost and I don't know what happened.

Am I thinking in the right way? This is so confusing. Is there a better way to think about it?


Solution

  • You're forgetting that the operator >>= doesn't return just f (h w) w, but rather \w -> f (h w) w. That is, it returns a function, not a number.

    By substituting it incorrectly you lost the outermost parameter w, so it's no wonder it remains free in your final expression.

    To do this correctly, you have to substitute function bodies for their calls completely, without dropping stuff.
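
To make the "it returns a function" point concrete, here is a sketch that copies the instance's definitions into ordinary functions. `returnFn`, `bindFn` and `addStuffFn` are hypothetical names, chosen so they do not clash with the instance that base already provides:

```haskell
-- Hypothetical stand-ins for the function monad's return and (>>=),
-- copied from the instance in the question.
returnFn :: a -> (r -> a)
returnFn x = \_ -> x

bindFn :: (r -> a) -> (a -> r -> b) -> (r -> b)
bindFn h f = \w -> f (h w) w  -- the result is itself a function of w

-- addStuff written with the stand-ins instead of do notation.
addStuffFn :: Int -> Int
addStuffFn = (*2) `bindFn` (\a -> (+10) `bindFn` (\b -> returnFn (a + b)))

main :: IO ()
main = do
  -- g is not a number: it still waits for its argument w.
  let g = (*2) `bindFn` (\a -> returnFn a) :: Int -> Int
  print (g 7)           -- prints 14
  print (addStuffFn 3)  -- prints 19
```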

    If you substitute the outermost >>=, you will get:

    (*2) >>= (\a -> ...) 
    ==
    \w -> (\a -> ...) (w*2) w
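
As a quick sanity check (a sketch; `inner` and `outerSubstituted` are hypothetical names), substituting only the outer >>= should not change the function:

```haskell
-- The inner part of the chain, still using the real (>>=) and return.
inner :: Int -> Int -> Int
inner = \a -> (+10) >>= (\b -> return (a + b))

-- Only the outermost (>>=) substituted by its definition:
--   (*2) >>= inner  ==  \w -> inner (w*2) w
outerSubstituted :: Int -> Int
outerSubstituted = \w -> inner (w * 2) w

main :: IO ()
main = print (all (\w -> outerSubstituted w == ((*2) >>= inner) w) [-100 .. 100])
```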
    

    Then, if you substitute the innermost >>=, you get:

    \a -> (+10) >>= (\b -> return (a+b))
    ==
    \a -> \w1 -> (\b -> return (a+b)) (w1 + 10) w1
    

    Note that I use w1 instead of w. This is to avoid name collisions later on when I combine the substitutions, because these two ws come from two different lambda abstractions, so they're different variables.
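
The same kind of check works for the inner substitution. A sketch with hypothetical names `innerBind` and `innerSubstituted`; w1 is a separate parameter here, just as in the substitution above:

```haskell
-- The inner bind before substitution (real (>>=) and return).
innerBind :: Int -> Int -> Int
innerBind = \a -> (+10) >>= (\b -> return (a + b))

-- After substituting that (>>=) by its definition. The bound
-- variable is named w1 so it is not confused with the outer w.
innerSubstituted :: Int -> Int -> Int
innerSubstituted = \a -> \w1 -> (\b -> return (a + b)) (w1 + 10) w1

main :: IO ()
main =
  print (and [ innerBind a w1 == innerSubstituted a w1
             | a <- [-20 .. 20], w1 <- [-20 .. 20] ])
```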

    Finally, substitute the return:

    return (a+b)
    ==
    \_ -> a+b
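
A small sketch confirming that, for functions, return x behaves like \_ -> x (and hence like const x); `returnAsLambda` is a hypothetical helper:

```haskell
-- return for the function monad, written out as a lambda.
returnAsLambda :: Int -> (Int -> Int)
returnAsLambda x = \_ -> x

main :: IO ()
main = do
  -- The real return, specialised to functions from Int.
  let viaReturn = return 42 :: Int -> Int
  print (viaReturn 0, viaReturn 999)                    -- prints (42,42)
  print (returnAsLambda 42 999 == const 42 (999 :: Int)) -- prints True
```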
    

    Now insert this last substitution into the previous one:

    \a -> (+10) >>= (\b -> return (a+b))
    ==
    \a -> \w1 -> (\b -> return (a+b)) (w1 + 10) w1
    ==
    \a -> \w1 -> (\b -> \_ -> a+b) (w1 + 10) w1
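
Once return has been substituted, no monad operations remain. A sketch comparing that version with the original inner chain (`innerNoReturn` and `innerOriginal` are hypothetical names):

```haskell
-- Inner chain with return already replaced by \_ -> a+b:
-- only lambdas, application and (+) remain.
innerNoReturn :: Int -> Int -> Int
innerNoReturn = \a -> \w1 -> (\b -> \_ -> a + b) (w1 + 10) w1

-- The original inner chain, for comparison.
innerOriginal :: Int -> Int -> Int
innerOriginal = \a -> (+10) >>= (\b -> return (a + b))

main :: IO ()
main =
  print (and [ innerNoReturn a w1 == innerOriginal a w1
             | a <- [-20 .. 20], w1 <- [-20 .. 20] ])
```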
    

    And finally insert this into the very first substitution:

    (*2) >>= (\a -> ...) 
    ==
    \w -> (\a -> ...) (w*2) w
    ==
    \w -> (\a -> \w1 -> (\b -> \_ -> a+b) (w1 + 10) w1) (w*2) w
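
The fully substituted expression can be pasted into Haskell almost verbatim. A sketch comparing it with the original do block; `fullySubstituted` is a hypothetical name:

```haskell
-- The fully substituted expression: no (>>=) or return left.
fullySubstituted :: Int -> Int
fullySubstituted = \w -> (\a -> \w1 -> (\b -> \_ -> a + b) (w1 + 10) w1) (w * 2) w

-- The original do block.
addStuff :: Int -> Int
addStuff = do
  a <- (*2)
  b <- (+10)
  return (a + b)

main :: IO ()
main = print (all (\w -> fullySubstituted w == addStuff w) [-100 .. 100])
```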
    

    And now that all substitutions are complete, we can reduce. Start by applying the innermost lambda \b -> ...:

    \w -> (\a -> \w1 -> (\_ -> a+w1+10) w1) (w*2) w
    

    Now apply the new innermost lambda \_ -> ...:

    \w -> (\a -> \w1 -> a+w1+10) (w*2) w
    

    Now apply \a -> ...:

    \w -> (\w1 -> w*2+w1+10) w
    

    And finally apply the only remaining lambda \w1 -> ...:

    \w -> w*2+w+10
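
Each reduction step can also be written down as its own function and checked against the others. A sketch with hypothetical names `stage0` to `stage4`; agreement on a range of inputs is evidence, not a proof, that every step preserved the meaning:

```haskell
-- Each stage of the beta-reduction above, as its own function.
stage0, stage1, stage2, stage3, stage4 :: Int -> Int
stage0 = \w -> (\a -> \w1 -> (\b -> \_ -> a + b) (w1 + 10) w1) (w * 2) w
stage1 = \w -> (\a -> \w1 -> (\_ -> a + w1 + 10) w1) (w * 2) w  -- after \b
stage2 = \w -> (\a -> \w1 -> a + w1 + 10) (w * 2) w             -- after \_
stage3 = \w -> (\w1 -> w * 2 + w1 + 10) w                       -- after \a
stage4 = \w -> w * 2 + w + 10                                   -- after \w1

main :: IO ()
main = do
  let stages = [stage0, stage1, stage2, stage3, stage4]
  print (map ($ 3) stages)  -- prints [19,19,19,19,19]
  -- Every stage agrees with stage0 on all inputs tried.
  print (and [s w == stage0 w | s <- stages, w <- [-100 .. 100]])
```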
    

    And voilà! The whole function reduces to \w -> (w*2) + (w+10), exactly as expected.