
Help in understanding pointfree code


When playing around with Pointfree I was presented with a piece of code that I can't seem to understand.

:pl map (\x -> x * x) [1..10]
-- map (join (*)) [1..10]

My main problem is that I don't get how join works here. I understand that it 'removes' one layer of a monadic wrapping (m (m a) to m a). I figure it boils down to something like [1..10] >>= (\x -> [x * x]), but I don't really get how the "extra layer" gets introduced. I get that join x = x >>= id, but then I'm still stuck on how that "duplicates" each value so that (*) gets two arguments. This has been bugging me for about half an hour now and I'm mostly annoyed at myself, because I feel like I have all the puzzle pieces but can't seem to fit them together...
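One way to see where the "extra layer" comes from is to use `join` at several different monads and let the types pick the instance. The sketch below is illustrative only: the names `flattened`, `collapsed`, `joinFn`, `squareOf5`, `pairUp` and `listGuess` are hypothetical, and it assumes a current GHC where the `Monad ((->) a)` instance is available without extra imports.

```haskell
import Control.Monad (join)

-- join :: Monad m => m (m b) -> m b
-- Which "layers" get collapsed depends on the monad m chosen by the type.

-- m = []: a list of lists is flattened.
flattened :: [Int]
flattened = join [[1, 2], [3]]

-- m = Maybe: nested Justs are collapsed.
collapsed :: Maybe Int
collapsed = join (Just (Just 5))

-- m = (->) a: m (m b) is a -> (a -> b), i.e. any curried two-argument
-- function whose arguments share a type. No layer has to be added;
-- (*) :: Int -> Int -> Int already has that shape.
joinFn :: (a -> a -> b) -> a -> b
joinFn = join

squareOf5 :: Int
squareOf5 = joinFn (*) 5

-- The same trick duplicates an argument into a pair.
pairUp :: Char -> (Char, Char)
pairUp = join (,)

-- The list-monad guess from the question: same numbers, but produced by
-- the list instance's >>=, not by the function instance that join (*) uses.
listGuess :: [Int]
listGuess = [1 .. 10] >>= \x -> [x * x]

main :: IO ()
main = do
  print flattened      -- [1,2,3]
  print collapsed      -- Just 5
  print squareOf5      -- 25
  print (pairUp 'x')   -- ('x','x')
  print listGuess      -- [1,4,9,16,25,36,49,64,81,100]
```

The annotation on `joinFn` is the key line: it compiles only because `(->) a` is a monad, and it shows that `m (m b)` at that monad is just an ordinary two-argument function.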

P.S. Don't worry, I wouldn't really use this pointfree version; this is pure curiosity and an attempt to understand Haskell better.


Solution

  • join is using the instance of Monad for (->) a, as defined in Control.Monad.Instances (now deprecated; in current versions of base the instance is available without any import). The instance is similar to Reader, but without an explicit newtype wrapper. It is defined like this:

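    -- For reference only: base already provides this instance,
    -- so it cannot be redefined in user code.
    instance Monad ((->) a) where
      -- return :: b -> (a -> b)
      return = const
      -- (>>=) :: (a -> b) -> (b -> a -> c) -> (a -> c)
      f >>= g = \x -> g (f x) x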
    

    If you now reduce join using this instance:

    join
    = (>>= id)
    = flip (\f g x -> g (f x) x) (\a -> a)
    = (\f x -> (\a -> a) (f x) x)
    = (\f x -> f x x)
    

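    As you can see, the instance for (->) a turns join into a function that takes a two-argument function and passes its single argument to it twice: join f = \x -> f x x. The "extra layer" is not introduced by join at all: with m = (->) a, the type m (m b) is a -> (a -> b), so any curried two-argument function whose arguments share a type, such as (*), already has that doubly wrapped shape. Because of this, join (*) is simply \x -> x * x, and map (join (*)) [1..10] is the same as map (\x -> x * x) [1..10].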
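The reduction can also be checked mechanically. The sketch below writes each line of the derivation as its own definition (the names `step1` to `step4` and `results` are hypothetical) and relies on the `Monad ((->) a)` instance that ships with base rather than redefining it. If the derivation is right, every form maps `[1..10]` to the same list of squares.

```haskell
import Control.Monad (join)

-- One definition per line of the reduction, all at the function monad.
step1 :: (a -> a -> b) -> a -> b
step1 = (>>= id)

step2 :: (a -> a -> b) -> a -> b
step2 = flip (\f g x -> g (f x) x) (\a -> a)

step3 :: (a -> a -> b) -> a -> b
step3 = \f x -> (\a -> a) (f x) x

step4 :: (a -> a -> b) -> a -> b
step4 = \f x -> f x x

-- Apply the library join and every step to (*) over [1..10].
results :: [[Integer]]
results = [map (j (*)) [1 .. 10] | j <- [join, step1, step2, step3, step4]]

main :: IO ()
main = mapM_ print results
```

Running `main` should print `[1,4,9,16,25,36,49,64,81,100]` five times, once for `join` and once for each intermediate form.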