Defining a monad from scratch in Haskell

After studying monads in Haskell (a subject that is very compelling for everything it implies), I wonder whether I could define a monad on my own without using the already defined typeclasses.

Instead of building on Functor (as Haskell's Monad does, through Applicative), I just want to define a monad per se, with its own fmap function. I also want to rename some functions, for example calling return unit instead.

A monad may be defined by the bind operator (>>=) and the function return, but it may also be defined in terms of return and join, since join can be expressed with bind: join m = m >>= id, and bind can in turn be recovered as m >>= f = join (fmap f m). So a monad could (technically) be defined in terms of return and join, plus fmap. The function fmap is required anyway (it is the basis of Functor in Haskell), and it could also be defined in terms of bind and return, I think as follows: fmap f m = m >>= return . f (an earlier version of this question had fmap f m = return . f, which was a typo).
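A minimal sketch to check these equations, using the Prelude's own Monad class (not the hand-rolled class this question is after): join and fmap are each written with >>= and return, and >>= is recovered from join and fmap. The names joinViaBind, fmapViaBind and bindViaJoin are made up for illustration.

```haskell
-- join expressed with bind: flatten by binding each inner value to itself.
joinViaBind :: Monad m => m (m a) -> m a
joinViaBind m = m >>= id

-- fmap expressed with bind and return: apply f, then re-wrap the result.
fmapViaBind :: Monad m => (a -> b) -> m a -> m b
fmapViaBind f m = m >>= return . f

-- The other direction: bind recovered from join and fmap.
bindViaJoin :: Monad m => m a -> (a -> m b) -> m b
bindViaJoin m f = joinViaBind (fmap f m)

main :: IO ()
main = do
  print (joinViaBind [[1, 2], [3 :: Int]])               -- [1,2,3]
  print (fmapViaBind (+ 1) (Just (41 :: Int)))           -- Just 42
  print (bindViaJoin [1, 2 :: Int] (\x -> [x, x * 10]))  -- [1,10,2,20]
```

Note that bindViaJoin still needs fmap: join alone only flattens, it cannot apply a function inside the structure.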

I know this wouldn't be as efficient as using the predefined Monad typeclass; it's just to understand the Haskell language better.

How can I accomplish that? Here is a rough sketch of the idea as I currently picture it, so it is not working code:

-- Just a sketch
infixr 9 ∘
(∘) :: (b -> c) -> (a -> b) -> a -> c
(∘) g f x = g (f x)
--(f ∘ g) x = f (g x)

-- My own 'fmap'
--mapper id  =  id
--mapper (f ∘ g) = mapper f ∘ mapper g

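-- My monad (does not compile as written)
class MyMonadBase (m :: * -> *) where   -- kind annotation needs KindSignatures under Haskell2010
    unit :: a -> m a   -- plays the role of return

    join :: m (m a) -> m a
    join = (>>= id)    -- problem: this (>>=) is Prelude's, which requires a Monad m instance

    mapper :: (a -> b) -> m a -> m b
    mapper f m = m >>= unit ∘ f   -- same problem: this class has no bind of its own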

data Tree a = Leaf a | Branch (Tree a) (Tree a)

instance MyMonadBase Tree where
    unit = Leaf
    join (Leaf x) = x
    join (Branch l r)  = Branch (join l) (join r)
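The two commented-out lines for mapper in the sketch are the functor laws. They cannot serve as definitions, but they can be checked on a sample value. This is a standalone sketch under one assumption: mapper is written directly by recursion over Tree (rather than derived from a bind), and Tree derives Eq so results can be compared.

```haskell
-- Standalone Tree with Eq so that law checks can compare results.
data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Structure-preserving map: apply f at every leaf, keep the branching shape.
mapper :: (a -> b) -> Tree a -> Tree b
mapper f (Leaf x)     = Leaf (f x)
mapper f (Branch l r) = Branch (mapper f l) (mapper f r)

sample :: Tree Int
sample = Branch (Branch (Leaf 1) (Leaf 3)) (Branch (Leaf 2) (Leaf 4))

-- Functor law 1: mapper id == id
identityLaw :: Bool
identityLaw = mapper id sample == sample

-- Functor law 2: mapper (f . g) == mapper f . mapper g
compositionLaw :: Bool
compositionLaw =
  mapper ((* 2) . (+ 1)) sample == (mapper (* 2) . mapper (+ 1)) sample

main :: IO ()
main = print (identityLaw, compositionLaw)  -- (True,True)
```

Passing on one sample is evidence, not a proof; for this mapper both laws follow by induction on the tree.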

Am I on the right track (conceptually)?


  • Okay, it wasn't that difficult. I had the misconception that implementing my own monad type would be extremely complex, but it's just a matter of applying the definition.

    -- My monad
    class MyMonad m where
        unit :: a -> m a
        join :: m (m a) -> m a
        mapf :: (a -> b) -> m a -> m b

    -- Testing MyMonad
    data Tree a = Leaf a | Branch (Tree a) (Tree a) deriving (Show)

    instance MyMonad Tree where
        unit = Leaf
        join (Leaf x) = x
        join (Branch l r) = Branch (join l) (join r)
        mapf f (Leaf x) = Leaf (f x)
        mapf f (Branch l r) = Branch (mapf f l) (mapf f r)

    t :: Tree Integer
    t = Branch (Branch (Leaf 1) (Leaf 3)) (Branch (Leaf 2) (Leaf 4))

    -- My bind (not part of the class; derived from join and mapf):
    -- map f over every leaf, then flatten the resulting tree of trees
    (>>>) :: MyMonad m => m a -> (a -> m b) -> m b
    xs >>> f = join (mapf f xs)

    -- Testing my bind
    extr :: Integer -> Tree Integer
    extr x = Branch (Leaf (x^2)) (Leaf (2^x))

    main :: IO ()
    main = print (t >>> extr)
    -- prints (wrapped here):
    -- Branch (Branch (Branch (Leaf 1) (Leaf 2)) (Branch (Leaf 9) (Leaf 8)))
    --        (Branch (Branch (Leaf 4) (Leaf 4)) (Branch (Leaf 16) (Leaf 16)))
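To check that the Tree instance really behaves as a monad, the monad laws can be stated in the unit/join/mapf vocabulary and tested on sample trees. This is a standalone sketch: it repeats the class and Tree from the code above (adding Eq so trees can be compared), and the Functor/Applicative/Monad instances at the end are an added illustration, not part of the original code, showing how the same three pieces plug into the standard classes so that >>= and do-notation work on Tree.

```haskell
-- Standalone copy of MyMonad and Tree, with Eq added for comparisons.
class MyMonad m where
  unit :: a -> m a
  join :: m (m a) -> m a
  mapf :: (a -> b) -> m a -> m b

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

instance MyMonad Tree where
  unit = Leaf
  join (Leaf x)       = x
  join (Branch l r)   = Branch (join l) (join r)
  mapf f (Leaf x)     = Leaf (f x)
  mapf f (Branch l r) = Branch (mapf f l) (mapf f r)

-- Monad laws in unit/join/mapf form:
--   join (unit m)      == m
--   join (mapf unit m) == m
--   join (join m)      == join (mapf join m)
leftUnit :: Tree Int -> Bool
leftUnit m = join (unit m) == m

rightUnit :: Tree Int -> Bool
rightUnit m = join (mapf unit m) == m

assoc :: Tree (Tree (Tree Int)) -> Bool
assoc m = join (join m) == join (mapf join m)

-- Added illustration: the same pieces reused for the standard classes.
instance Functor Tree where
  fmap = mapf

instance Applicative Tree where
  pure = unit
  mf <*> mx = join (mapf (\f -> mapf f mx) mf)

instance Monad Tree where
  m >>= f = join (mapf f m)

t :: Tree Int
t = Branch (Branch (Leaf 1) (Leaf 3)) (Branch (Leaf 2) (Leaf 4))

-- A three-level nested tree for the associativity law.
ttt :: Tree (Tree (Tree Int))
ttt = Branch (Leaf (Leaf t)) (Leaf (Branch (Leaf (Leaf 5)) (Leaf t)))

extr :: Int -> Tree Int
extr x = Branch (Leaf (x ^ (2 :: Int))) (Leaf (2 ^ x))

-- The result shown in the comment at the end of the code above.
expected :: Tree Int
expected =
  Branch (Branch (Branch (Leaf 1) (Leaf 2)) (Branch (Leaf 9) (Leaf 8)))
         (Branch (Branch (Leaf 4) (Leaf 4)) (Branch (Leaf 16) (Leaf 16)))

main :: IO ()
main = do
  print (leftUnit t, rightUnit t, assoc ttt)  -- (True,True,True)
  print ((t >>= extr) == expected)            -- True
```

As with any sample-based check, this supports the laws rather than proving them; for this Tree, all three also follow by induction on the structure.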