
Implementing factorial and fibonacci using the State monad (as a learning exercise)


I worked my way through Mike Vanier's monad tutorial (which is excellent) and I'm working on a few of the exercises in his post on how to use a "State" monad.

In particular, he suggests an exercise that consists of writing factorial and fibonacci functions using a State monad. I gave it a shot and came up with the answers below. (I find do notation pretty confusing, hence my choice of syntax.)

Neither of my implementations looks particularly "Haskell-y", and, in the interest of not internalizing bad practices, I thought I'd ask folks how they would have gone about implementing these functions (still using the State monad). Is it possible to write this code far more simply (aside from switching to do notation)? I strongly suspect it is.
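Since do notation comes up here, it may help to see that a do block is sugar for the same >>= and >> chain. Below is a minimal sketch of one factorial loop over a (counter, product) state, written both ways. The names factBind, factDo and factorialDo are illustrative and do not come from the tutorial:

```haskell
import Control.Monad.State (State, evalState, get, put)

-- Explicit >>= style, as in the question. State is (counter, product).
factBind :: State (Integer, Integer) Integer
factBind =
  get >>= \(n, f) ->
    if n == 0
      then return f
      else put (n - 1, f * n) >> factBind

-- The same loop in do notation. Roughly, a "pat <- m" line becomes
-- "m >>= \pat -> ..." and a bare "m" line becomes "m >> ...".
factDo :: State (Integer, Integer) Integer
factDo = do
  (n, f) <- get
  if n == 0
    then return f
    else do
      put (n - 1, f * n)
      factDo

-- Run the do-notation version from the starting state (n, 1).
factorialDo :: Integer -> Integer
factorialDo n = evalState factDo (n, 1)
```

Both versions give 120 for a starting state of (5, 1). The tuple pattern in `(n, f) <- get` cannot fail, so it needs no MonadFail instance, which State does not have.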


I'm aware that it's a bit impractical to use a State monad for this purpose, but this is purely a learning exercise (pun most certainly intended).

That said, the performance is not much worse: calculating the factorial of 100000 (the answer is 456,574 digits long) took ~1.2 sec in GHCi with the non-monadic version vs. ~1.5 sec with the State monad version.
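Laziness may account for part of that gap, though that is an assumption and has not been measured here. One candidate cause is that `put (n-1, f*n)` stores `f*n` as an unevaluated thunk, so a long chain of multiplications builds up before any of them is computed. Here is a sketch that forces the product at every step, using the hypothetical names factStrict and factorialStrict:

```haskell
import Control.Monad.State.Strict (State, evalState, get, put)

-- Same (counter, product) state shape as the question's fact_state.
-- The Strict module makes >>= strict in the (result, state) pair, but it
-- does not force the fields inside the state tuple, hence the explicit seq.
factStrict :: State (Integer, Integer) Integer
factStrict =
  get >>= \(n, acc) ->
    if n <= 0
      then return acc
      else
        -- Evaluate the new product now instead of storing "acc * n" as a thunk.
        let acc' = acc * n
        in acc' `seq` (put (n - 1, acc') >> factStrict)

factorialStrict :: Integer -> Integer
factorialStrict n = evalState factStrict (n, 1)
```

In GHCi, `factorialStrict <$> [0..5]` gives `[1,1,2,6,24,120]`. Whether this actually narrows the timing gap would need measuring. Using `n <= 0` as the stopping test also makes negative inputs return 1 instead of looping.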

import Control.Monad.State (State, get, put, evalState)
import Data.List (unfoldr)

fibonacci :: Integer -> Integer
fibonacci 0 = 0
fibonacci n = evalState fib_state (1,0,1,n)

fib_state :: State (Integer,Integer,Integer,Integer) Integer
fib_state = get >>=
            \s ->
              let (p1,p2,ctr,n) = s
              in case compare ctr n of
                   LT -> put (p1+p2, p1, ctr+1, n) >> fib_state
                   _  -> return p1

factorial :: Integer -> Integer
factorial n = evalState fact_state (n,1)

fact_state :: State (Integer,Integer) Integer
fact_state = get >>=
             \s -> 
               let (n,f) = s 
               in case n of
                      0 -> return f
                      _ -> put (n-1,f*n) >> fact_state

-------------------------------------------------------------------
-- Functions below are used only to test the output of the functions above

factorial' :: Integer -> Integer
factorial' n = product [1..n]

fibonacci' :: Int -> Integer
fibonacci' 0 = 0
fibonacci' 1 = 1
fibonacci' n =
  let getFst (a,_,_) = a
  in  getFst
    $ last 
    $ unfoldr (\(p1,p2,cnt) -> 
               if cnt == n
                  then Nothing
                  else Just ((p1,p2,cnt)
                            ,(p1+p2,p1,cnt+1))
              ) (1,1,1) 

Solution

  • Your functions seem to be a bit more complicated than they need to be, but you have the right idea. For the factorial, all you need to keep track of is the current number you're multiplying by and the number that you've accumulated so far. So we'll say that State Integer Integer is a computation whose state is the current number and whose result is that number's factorial (Integer, as in your code, so that large results don't overflow the way a 64-bit Int does past 20!):

    import Control.Monad.State.Strict (State, get, put, evalState)

    fact_state :: State Integer Integer
    fact_state = get >>= \x -> if x <= 1
                               then return 1
                               else (put (x - 1) >> fmap (*x) fact_state)
    
    factorial :: Integer -> Integer
    factorial = evalState fact_state
    
    Prelude Control.Monad.State.Strict Control.Applicative> factorial <$> [1..10]
    [1,2,6,24,120,720,5040,40320,362880,3628800]
    

    The fibonacci sequence is similar. You need to keep the last two numbers, so you know what you're going to add together next, plus a counter of how many steps remain:

    fibs_state :: State (Integer, Integer, Integer) Integer
    fibs_state = get >>= \(x1, x2, n) -> if n == 0
                                         then return x1
                                         else (put (x2, x1+x2, n-1) >> fibs_state)
    
    fibonacci :: Integer -> Integer
    fibonacci n = evalState fibs_state (0, 1, n)
    
    Prelude Control.Monad.State.Strict Control.Applicative> fibonacci <$> [1..10]
    [1,1,2,3,5,8,13,21,34,55]
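Here is a further sketch that is not part of the answer above. The loop counter can be moved out of the state entirely by letting Control.Monad's `replicateM_` (or the Prelude's `mapM_`) repeat a single step. The state then holds only what each step updates. The names fibStep, fibonacciM and factorialM are illustrative:

```haskell
import Control.Monad (replicateM_)
import Control.Monad.State (State, execState, modify)

-- Fibonacci: the state is a window (F k, F (k+1)) of consecutive numbers.
-- Each step slides the window forward by one.
fibStep :: State (Integer, Integer) ()
fibStep = modify (\(a, b) -> (b, a + b))

-- replicateM_ runs fibStep n times (zero times if n <= 0), so the
-- counter never has to live in the state. Starting from (F 0, F 1),
-- the first component after n steps is F n.
fibonacciM :: Int -> Integer
fibonacciM n = fst (execState (replicateM_ n fibStep) (0, 1))

-- Factorial: the state is just the running product; mapM_ walks 1..n
-- and multiplies each number in.
factorialM :: Integer -> Integer
factorialM n = execState (mapM_ (\k -> modify (* k)) [1 .. n]) 1
```

In GHCi, `fibonacciM <$> [1..10]` gives `[1,1,2,3,5,8,13,21,34,55]`, the same as the fibonacci above. Because `replicateM_` owns the counter, there is no stopping condition to get wrong, and a negative `n` simply performs no steps.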