Tags: haskell, monads, state-monad

Haskell instance of `bind` for a custom type


I'm trying to implement the bind operator (>>=) in a Monad instance for the custom type ST a.

I found this way to do it, but I don't like the hardcoded 0.

Is there any way to implement it without the hardcoded 0, while still respecting the type of the function?

newtype ST a = S (Int -> (a, Int))
    
-- This may be useful to implement ">>=" (bind), but it is not mandatory to use it
runState :: ST a -> Int -> (a, Int)
runState (S s) = s
        
instance Monad ST where
      return :: a -> ST a
      return x = S (\n -> (x, n))
       
      (>>=) :: ST a -> (a -> ST b) -> ST b
      s >>= f = f (fst (runState s 0))
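To see concretely what the hardcoded 0 does, here is a small sketch. It repeats ST and runState so it compiles on its own, and adds a hypothetical computation `tick` that returns the current state and increments it. `bindZero` is a hypothetical name for the question's bind, written as a plain function so that no Monad instance is needed.

```haskell
-- The question's type and accessor, repeated so the snippet is self-contained.
newtype ST a = S (Int -> (a, Int))

runState :: ST a -> Int -> (a, Int)
runState (S s) = s

-- Hypothetical helper: return the current state, then increment it.
tick :: ST Int
tick = S (\n -> (n, n + 1))

-- The question's bind as a plain function (hypothetical name).
-- s is always run on state 0, and its resulting state is discarded.
bindZero :: ST a -> (a -> ST b) -> ST b
bindZero s f = f (fst (runState s 0))

-- Two ticks in a row, started from state 10.
twoTicksZero :: (Int, Int)
twoTicksZero = runState (tick `bindZero` \_ -> tick) 10

-- A continuation that just reports the value it was handed.
valueFromZero :: (Int, Int)
valueFromZero = runState (tick `bindZero` \a -> S (\n -> (a, n))) 10
```

Loaded in GHCi, `twoTicksZero` should be `(10,11)`: the first tick runs on 0 instead of 10, and its state change is thrown away. `valueFromZero` is `(0,10)`, because the value passed to `f` always comes from state 0, whatever the real incoming state is.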

Solution

  • I often find it easier to follow such code after rewriting it as a certain kind of pseudocode. For example, starting with

    instance Monad ST where
          return :: a -> ST a
          return x = S (\n -> (x, n))
    

    we arrive at

      runState (return x) n = (x, n)
    

    which expresses exactly the same thing. It is now a kind of definition by an interaction law that return must obey, and it lets me ignore the "noise" of the S wrapping and focus on the essential part.

    Similarly, then, we have

          (>>=) :: ST a -> (a -> ST b) -> ST b
          s >>= f = -- f (fst (runState s 0))   -- nah, 0? what's that?
          -- 
          -- runState (s >>= f) n = runState (f a) i where 
          --                                   (a, i) = runState s n
          --
                    S $       \ n ->       let (a, i) = runState s n in
                                    runState (f a) i
    

    because now we have an Int in scope, n, which will be provided to us when the combined computation s >>= f is "run", that is, when runState is applied to it.

    Of course, nothing actually runs until it is demanded, ultimately from main, but "running" is a helpful metaphor to keep in mind.
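For reference, here is a minimal sketch of this bind as a complete module. It assumes a modern GHC, where Monad has Functor and Applicative as superclasses (so those instances are added), and it enables InstanceSigs because the question writes type signatures inside the instance. The helpers `get`, `put` and `tick` are hypothetical, for illustration only.

```haskell
{-# LANGUAGE InstanceSigs #-}
-- Sketch: the answer's bind in a module that compiles on its own.
import Control.Monad (ap, liftM)

newtype ST a = S (Int -> (a, Int))

runState :: ST a -> Int -> (a, Int)
runState (S s) = s

-- Superclass instances required by Monad, derived from the Monad instance.
instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\n -> (x, n))
  (<*>) = ap

instance Monad ST where
  return :: a -> ST a
  return = pure

  (>>=) :: ST a -> (a -> ST b) -> ST b
  s >>= f = S $ \n ->
    let (a, i) = runState s n   -- run s on the incoming state n
    in runState (f a) i         -- continue with f a from s's final state i

-- Hypothetical helpers: read and overwrite the Int state.
get :: ST Int
get = S (\n -> (n, n))

put :: Int -> ST ()
put n = S (\_ -> ((), n))

-- Hypothetical: return the current state, then increment it.
tick :: ST Int
tick = do
  n <- get
  put (n + 1)
  return n

twoTicks :: (Int, Int)
twoTicks = runState (tick >>= \_ -> tick) 10

threeTicks :: ([Int], Int)
threeTicks = runState (do { a <- tick; b <- tick; c <- tick; return [a, b, c] }) 0
```

Here `twoTicks` evaluates to `(11,12)`: unlike the hardcoded-0 version, the second tick starts from the state 11 left behind by the first. Likewise `threeTicks` is `([0,1,2],3)`.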

    The way we've defined it is both the simplest and the most general, which is usually the way to go. There are other ways to make the types fit, though.

    One is to use n twice, passing it to the second runState as well, but this leaves i unused, so whatever state change s makes is discarded.
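A short sketch of that "n twice" variant, assuming the same ST type. `bindTwice`, `ret` and `tick` are hypothetical names. It suggests that the variant is not just less useful but also breaks the monad right-identity law, since `s` followed by `ret` no longer behaves like `s`.

```haskell
newtype ST a = S (Int -> (a, Int))

runState :: ST a -> Int -> (a, Int)
runState (S s) = s

-- Hypothetical: return the current state, then increment it.
tick :: ST Int
tick = S (\n -> (n, n + 1))

-- Hypothetical stand-in for return, without a Monad instance.
ret :: a -> ST a
ret x = S (\n -> (x, n))

-- The "use n twice" alternative: the types fit, but s's state is dropped.
bindTwice :: ST a -> (a -> ST b) -> ST b
bindTwice s f = S $ \n ->
  let (a, _i) = runState s n   -- _i is computed but never used
  in runState (f a) n          -- f a also starts from the original n

twoTicksTwice :: (Int, Int)
twoTicksTwice = runState (tick `bindTwice` \_ -> tick) 10

-- Right identity would require this to equal runState tick 10 == (10, 11).
rightIdTwice :: (Int, Int)
rightIdTwice = runState (tick `bindTwice` ret) 10
```

Both ticks see state 10, so `twoTicksTwice` is `(10,11)`, and `rightIdTwice` is `(10,10)` rather than `(10,11)`.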

    Another way is to flip the arrow of time with respect to the state passing:

                    S $       \ n ->       let (a, i2) = runState s i 
                                               (b, i ) = runState (f a) n
                                            in (b, i2)
    

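    which is a bit weird, to say the least. s still runs first (as expected for the s >>= f combination) to produce the value a from which f creates the second computation stage, but the state is passed in the opposite direction: f a receives the incoming state n, and the state i it leaves behind is what s runs on. This works only because let bindings in Haskell are recursive and lazy, so i can appear in the first binding even though the second one defines it.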
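Here is a sketch of this "reverse state" bind as a plain function, assuming the same ST type; `bindRev` and `tick` are hypothetical names. It runs the same two-tick program used for the forward bind so the direction of state flow can be compared.

```haskell
newtype ST a = S (Int -> (a, Int))

runState :: ST a -> Int -> (a, Int)
runState (S s) = s

-- Hypothetical: return the current state, then increment it.
tick :: ST Int
tick = S (\n -> (n, n + 1))

-- The answer's flipped bind (hypothetical name). let is recursive and its
-- pattern bindings are lazy, so i may be used before the line defining it.
bindRev :: ST a -> (a -> ST b) -> ST b
bindRev s f = S $ \n ->
  let (a, i2) = runState s i       -- s runs on the state that f a leaves
      (b, i)  = runState (f a) n   -- f a receives the incoming state n
  in (b, i2)                       -- s's final state becomes the result

twoTicksRev :: (Int, Int)
twoTicksRev = runState (tick `bindRev` \_ -> tick) 10
```

`twoTicksRev` should be `(10,12)`, whereas the forward bind gives `(11,12)` for the same program: the second tick now sees the initial state 10, and the first sees 11, the state left by the second. If `f` had to inspect `a` before the state `i` could be known, the recursive let could hang or make GHC report `<<loop>>`, so this variant needs care.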