
Why does join . (flip fmap) have type ((A -> B) -> A) -> (A -> B) -> B?


Some playing around with functors and monads in ghci led me to a value whose type and behaviour I would like to understand better.

The type of \x -> join . x is (Monad m) => (a -> m (m b)) -> (a -> m b) and the type of \y -> y . (flip fmap) is (Functor f) => (((a -> b) -> f b) -> c) -> (f a -> c).
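One way to double-check these two types is to write them down as explicit signatures and let GHC verify them. This is a minimal sketch; the names joinAfter and preFlipFmap are hypothetical and only exist for the check. Load the file in GHCi: if a signature were wrong, it would be rejected.

```haskell
import Control.Monad (join)

-- \x -> join . x, given a hypothetical name and an explicit signature.
joinAfter :: Monad m => (a -> m (m b)) -> a -> m b
joinAfter = \x -> join . x

-- \y -> y . flip fmap, likewise; y itself returns some result type c.
preFlipFmap :: Functor f => (((a -> b) -> f b) -> c) -> f a -> c
preFlipFmap = \y -> y . flip fmap
```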

Version 8.2.2 of ghci permits the definition h = join . (flip fmap).

Why does h have type ((A -> B) -> A) -> (A -> B) -> B?

In particular, why do the functor and monad constraints disappear? Is this really the correct and expected behaviour? As a follow-up, I would also like to ask:

Why does evaluating h (\f -> f u) (\x -> x + v) for integers u and v give u + 2v in every case?


Solution

  • In short: through type inference, Haskell determines that m and f are both the same partially applied function arrow, (->) (A -> B).

    Deriving the type

    Well, let us do the math. The function join . (flip fmap) is simply your lambda expression \x -> join . x applied to the argument flip fmap, so:

    h = (\x -> join . x) (flip fmap)
    

    Now the lambda expression has type:

    (\x -> join . x) :: Monad m =>   (a -> m (m b)) -> (a -> m b)
    

    Now the argument flip fmap has type:

    flip fmap        :: Functor f => f c -> ((c -> d) -> f d)
    

    (Here we use c and d instead of a and b, because the type variables of flip fmap are not necessarily the same as those of the lambda expression.)

    The type of flip fmap must therefore unify with the type of the lambda expression's argument, hence:

      Monad m =>   a   -> m (m b)
    ~ Functor f => f c -> ((c -> d) -> f d)
    ---------------------------------------
    a ~ f c, m (m b) ~ ((c -> d) -> f d)
    

    So a is the same type as f c (the tilde ~ denotes type equality).

    But we still have to unify the remaining part:

      Monad m =>   m (m b)
    ~ Functor f => ((c -> d) -> f d)
    --------------------------------
    m ~ (->) (c -> d), m b ~ f d
    

    Hence m is the same as (->) (c -> d): a function type whose input type, here c -> d, is fixed, and whose output type is the type parameter of m.

    That means m b ~ (c -> d) -> b ~ f d, so f ~ (->) (c -> d) and b ~ d. A further consequence: since a ~ f c, we know that a ~ (c -> d) -> c.

    So to list what we derived:

    f ~ m
    m ~ (->) (c -> d)
    b ~ d
    a ~ (c -> d) -> c
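These instantiations can be made explicit with GHC's TypeApplications extension (available since GHC 8.0). The sketch below, using the hypothetical name hBy, pins both fmap and join to the function type (->) (c -> d), i.e. f ~ m ~ (->) (c -> d), and states the derived type. If the derivation were wrong, GHC would reject it.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

import Control.Monad (join)

-- The explicit forall brings c and d into scope in the body (ScopedTypeVariables),
-- so the type applications below can mention (->) (c -> d).
hBy :: forall c d. ((c -> d) -> c) -> (c -> d) -> d
hBy = join @((->) (c -> d)) . flip (fmap @((->) (c -> d)))
```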
    

    We can now "specialize" the types of both our lambda expression and our flip fmap function:

    (\x -> join . x)
        :: (((c -> d) -> c) -> (c -> d) -> (c -> d) -> d) -> ((c -> d) -> c) -> (c -> d) -> d
    flip fmap
        ::  ((c -> d) -> c) -> (c -> d) -> (c -> d) -> d
    

    and the type of flip fmap now exactly matches the type of the lambda expression's argument. So the type of (\x -> join . x) (flip fmap) is the result type of the lambda expression, namely:

    (\x -> join . x) (flip fmap)
        :: ((c -> d) -> c) -> (c -> d) -> d
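As a sanity check, the specialized types can be given to the lambda expression and to flip fmap directly. This is a sketch; joinAfterSpecialized, flipFmapSpecialized and hSpecialized are hypothetical names. GHC accepts these signatures because each is an instance of the corresponding polymorphic type.

```haskell
import Control.Monad (join)

-- \x -> join . x at m ~ (->) (c -> d), a ~ (c -> d) -> c, b ~ d
joinAfterSpecialized
  :: (((c -> d) -> c) -> (c -> d) -> (c -> d) -> d)
  -> ((c -> d) -> c) -> (c -> d) -> d
joinAfterSpecialized x = join . x

-- flip fmap at f ~ (->) (c -> d)
flipFmapSpecialized :: ((c -> d) -> c) -> (c -> d) -> (c -> d) -> d
flipFmapSpecialized = flip fmap

-- Applying the first to the second gives the derived result type.
hSpecialized :: ((c -> d) -> c) -> (c -> d) -> d
hSpecialized = joinAfterSpecialized flipFmapSpecialized
```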
    

    Of course, we have not yet derived what this function actually computes, but we are already a step further.

    Deriving the implementation

    Since we now know that m ~ (->) (c -> d), we should look up the Monad instance for the function arrow:

    instance Monad ((->) r) where
        f >>= k = \ r -> k (f r) r
    

    So given a function f :: r -> a as the left operand and a function k :: a -> (r -> b) ~ a -> r -> b as the right operand, we construct a new function that maps an input x to k (f x) x, that is, k applied to both f x and the original x. Informally, it first preprocesses the input x with f, and then processes the result together with the original input.
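A small sketch of this behaviour: bindFn is a hypothetical hand-written copy of (>>=) for the ((->) r) monad, and viaBind uses the library (>>=) directly. Both preprocess the input with (+ 1) and then multiply that result by the original input, so for an input of 5 both give (5 + 1) * 5 = 30.

```haskell
-- Hypothetical restatement of (>>=) for the ((->) r) monad.
bindFn :: (r -> a) -> (a -> r -> b) -> r -> b
bindFn f k = \x -> k (f x) x

-- Preprocess the input with (+ 1), then combine that result with the original input.
viaBind :: Int -> Int
viaBind = (+ 1) >>= \a x -> a * x

-- The same computation through the hand-written version.
viaBindFn :: Int -> Int
viaBindFn = bindFn (+ 1) (\a x -> a * x)
```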

    Now join :: Monad m => m (m a) -> m a is implemented as:

    join :: Monad m => m (m a) -> m a
    join x = x >>= id
    

    So for the (->) r monad, this becomes:

    -- specialized for m ~ (->) r
    join f = \r -> id (f r) r
    

    Since id :: a -> a (the identity function) returns its argument, we can further simplify it to:

    -- specialized for m ~ (->) r
    join f = \r -> (f r) r
    

    or cleaner:

    -- specialized for m ~ (->) r
    join f x = f x x
    

    So join takes a function f of two arguments and applies it to the same argument twice.
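A quick sketch comparing the library join with the two hand-written versions above; joinLibrary, joinViaBind and joinTwice are hypothetical names. For example, join (+) 3 evaluates to 3 + 3 = 6, and join (*) 4 to 4 * 4 = 16.

```haskell
import Control.Monad (join)

-- The library join, used at m ~ (->) r.
joinLibrary :: (r -> r -> a) -> r -> a
joinLibrary = join

-- join x = x >>= id, written out for functions.
joinViaBind :: (r -> r -> a) -> r -> a
joinViaBind f = f >>= id

-- The fully simplified form: feed the same argument twice.
joinTwice :: (r -> r -> a) -> r -> a
joinTwice f x = f x x
```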

    Furthermore, the Functor instance for the function arrow is defined as:

    instance Functor ((->) r) where
        fmap = (.)
    

    So fmap g f uses g as a "post-processor" on the result of f: it constructs a new function that first applies f and then post-processes the result with g.
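A minimal illustration of fmap as composition for functions; postProcessed and flipFmapFn are hypothetical names. Since fmap (* 2) (+ 1) is (* 2) . (+ 1), applying it to 3 gives (3 + 1) * 2 = 8. Flipping fmap puts the "pre" function first, so flip fmap a x = x . a, which is exactly the shape used in the derivation.

```haskell
-- fmap for functions is composition: first (+ 1), then (* 2).
postProcessed :: Int -> Int
postProcessed = fmap (* 2) (+ 1)

-- flip fmap takes the function to run first as its first argument.
flipFmapFn :: (r -> a) -> (a -> b) -> r -> b
flipFmapFn = flip fmap
```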

    Now that we have specialized join and fmap for the function Functor and Monad instances, we can derive the implementation as:

    -- alternative implementation
    h = (.) (\f x -> f x x) (flip (.))
    

    or by using more lambda expressions:

    h = \a -> (\f x -> f x x) ((flip (.)) a)
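To confirm that these rewrites do not change the function, the sketch below puts the original point-free definition next to the two alternatives and compares them on a small grid of inputs; h, hAlt and hLam are hypothetical names for this check only.

```haskell
import Control.Monad (join)

-- The original point-free definition from the question.
h :: ((c -> d) -> c) -> (c -> d) -> d
h = join . flip fmap

-- join and fmap replaced by their ((->) r) implementations.
hAlt :: ((c -> d) -> c) -> (c -> d) -> d
hAlt = (.) (\f x -> f x x) (flip (.))

-- The same, with the outer composition written as a lambda.
hLam :: ((c -> d) -> c) -> (c -> d) -> d
hLam = \a -> (\f x -> f x x) ((flip (.)) a)
```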
    

    which we can now simplify step by step:

    h = \a -> (\f x -> f x x) ((\y z -> z . y) a)
    
    -- apply the lambda expression to a
    h = \a -> (\f x -> f x x) (\z -> z . a)
    
    -- apply the first lambda expression to (\z -> z . a)
    h = \a -> (\x -> (\z -> z . a) x x)
    
    -- cleaning syntax
    h a = (\x -> (\z -> z . a) x x)
    
    -- cleaning syntax
    h a x = (\z -> z . a) x x
    
    -- apply lambda expression
    h a x = (x . a) x
    
    -- expand the composition: (x . a) x = x (a x)
    h a x = x (a x)
    

    So h takes two arguments, a and x: it applies a to x, and then passes the result to x again.
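The simplified form is not tied to numbers. As a sketch (hSimple and hOrig are hypothetical names), take a = \f -> replicate (f "hi") 'z' and x = length: then a length is "zz", and length "zz" is 2, so both definitions return 2.

```haskell
import Control.Monad (join)

-- The fully simplified form: apply a to x, then apply x to the result.
hSimple :: ((c -> d) -> c) -> (c -> d) -> d
hSimple a x = x (a x)

-- The original definition, for comparison.
hOrig :: ((c -> d) -> c) -> (c -> d) -> d
hOrig = join . flip fmap
```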

    Sample usage

    Your sample usage is:

    h (\f -> f u) (\x -> x + v)
    

    or, with an operator section:

    h (\f -> f u) (+v)
    

    which we can evaluate step by step:

       h (\f -> f u) (+v)
    -> (+v) ((\f -> f u) (+v))
    -> (+v) ((+v) u)
    -> (+v) (u+v)
    -> ((u+v)+v)
    

    So the result is (u + v) + v, that is, u + 2v, for any integers u and v.
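To check the u + 2v claim on more than a single pair, here is a small sketch that evaluates the sample expression over a grid of integers; sample and matchesFormula are hypothetical names introduced for this check. For instance, sample 3 4 evaluates to (3 + 4) + 4 = 11.

```haskell
import Control.Monad (join)

h :: ((c -> d) -> c) -> (c -> d) -> d
h = join . flip fmap

-- The sample usage from the question, as a function of u and v.
sample :: Int -> Int -> Int
sample u v = h (\f -> f u) (+ v)

-- True when sample u v equals u + 2v for every pair drawn from the given range.
matchesFormula :: [Int] -> Bool
matchesFormula range = and [sample u v == u + 2 * v | u <- range, v <- range]
```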