When playing around with Pointfree I was presented with a piece of code that I can't seem to understand.
:pl map (\x -> x * x) [1..10]
-- map (join (*)) [1..10]
My main problem is that I don't get how join works here. I understand that it 'removes' one layer of a monadic wrapping (m (m a) to m a). I figure it boils down to something like [1..10] >>= (\x -> [x * x]), but I don't really get how the "extra layer" gets introduced. I get that join x = x >>= id, but then I'm still stuck on how that "duplicates" each value so that (*) gets two arguments. This has been bugging me for about half an hour now and I'm mostly annoyed at myself, because I feel like I have all the puzzle pieces but can't seem to fit them together...
P.S. Don't worry, I wouldn't really use this pointfree version; this is pure curiosity and an attempt to understand Haskell better.
join is using the Monad instance for (->) a, as defined in Control.Monad.Instances (recent versions of base provide this instance by default, so that import is no longer needed). The instance is similar to Reader, but without an explicit wrapper. It is defined like this:
instance Monad ((->) a) where
    -- return :: b -> (a -> b)
    return = const
    -- (>>=) :: (a -> b) -> (b -> a -> c) -> (a -> c)
    f >>= g = \x -> g (f x) x
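To experiment with this definition without clashing with the instance that base already ships, the two methods can be copied out as ordinary functions. This is only a sketch: the names returnFn, bindFn and succTimesSelf are made up for illustration and are not part of any library.

```haskell
-- Standalone copies of the instance methods for functions.
-- The names returnFn and bindFn are hypothetical, chosen to avoid
-- clashing with the real Monad instance for (->) a.

returnFn :: b -> (a -> b)
returnFn = const

bindFn :: (a -> b) -> (b -> a -> c) -> (a -> c)
bindFn f g = \x -> g (f x) x

-- Both sides of the bind receive the same argument x:
-- first (+ 1) is applied to x, then the result and x itself
-- are handed to the continuation.
succTimesSelf :: Int -> Int
succTimesSelf = (+ 1) `bindFn` (\y x -> y * x)
```

Loaded into GHCi, succTimesSelf 3 evaluates to (3 + 1) * 3, that is 12. The point to notice is that x appears twice on the right-hand side of bindFn, which is where the "duplication" comes from.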
If you now reduce join using this instance:
join
  = (>>= id)
  = flip (\f g x -> g (f x) x) (\a -> a)
  = (\f x -> (\a -> a) (f x) x)
  = (\f x -> f x x)
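One way to convince yourself that each step of this reduction means the same thing is to give every step a name and let the type checker and a quick comparison confirm it. The sketch below assumes a reasonably recent GHC where the function instance is in scope by default; the names step1 to step4 are illustrative only, and the types are fixed to Int to keep the comparison simple.

```haskell
import Control.Monad (join)

-- Each reduction step written out at one concrete type.
-- step1 .. step4 are hypothetical names for illustration.
step1, step2, step3, step4 :: (Int -> Int -> Int) -> Int -> Int
step1 = join
step2 = (>>= id)
step3 = flip (\f g x -> g (f x) x) (\a -> a)
step4 = \f x -> f x x

-- Apply every step to the same function and argument so the
-- results can be compared side by side.
allSteps :: (Int -> Int -> Int) -> Int -> [Int]
allSteps f x = [s f x | s <- [step1, step2, step3, step4]]
```

For example, allSteps (*) 7 gives four copies of 49, and allSteps (-) 7 gives four zeros, since each step ends up computing f x x.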
As you can see, the instance for (->) a turns join into a function that passes the same argument to a two-argument function twice. Because of this, join (*) is simply \x -> x * x.
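To tie this back to the original expression, here is a small sketch that compares the pointfree and pointful versions directly. Note that join is used at the function type here, not at the list type, so the list monad never comes into play. The names squareViaJoin, squaresPointfree and squaresPointful are made up for this example.

```haskell
import Control.Monad (join)

-- join specialised to functions: feeds its argument to (*) twice.
squareViaJoin :: Integer -> Integer
squareViaJoin = join (*)

-- The pointfree form suggested by :pl.
squaresPointfree :: [Integer]
squaresPointfree = map (join (*)) [1..10]

-- The original pointful form.
squaresPointful :: [Integer]
squaresPointful = map (\x -> x * x) [1..10]
```

In GHCi, squaresPointfree == squaresPointful evaluates to True, and both lists are [1,4,9,16,25,36,49,64,81,100].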