Tags: performance, haskell, state, state-monad

Haskell State monad vs state as parameter performance test


I have started learning the State monad, and one idea bothers me: instead of passing an accumulator as a parameter, we can wrap everything in the State monad.

So I wanted to compare the performance of using the State monad versus passing the accumulator as a parameter.

So I created two functions:

sum1 :: Int -> [Int] -> Int
sum1 x [] = x
sum1 x (y:xs) =  sum1 (x + y) xs

and

import Control.Monad.State (execState, modify)

sumState :: [Int] -> Int
sumState xs = execState (traverse f xs) 0
    where f n = modify (n+)

I compared them on the input list [1..1000000000].

  • sumState running time was around 15s
  • sum1 around 5s

There is a clear winner, but then I realised that sumState can be optimised:

  1. We can use the strict version of modify, modify'
  2. We don't actually need the list that traverse outputs, so we can use traverse_ instead
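Point 1 concerns when the new state gets evaluated. The sketch below uses two hand-written stand-ins, lazyModify and strictModify (hypothetical names, not the library definitions), that capture the difference between modify and modify'. With lazyModify the state becomes an unevaluated chain ((0 + 1) + 2) + ... that is only collapsed at the end. strictModify forces each new state with seq before storing it, so no chain builds up.

```haskell
import Control.Monad.State (State, execState, state)
import Data.Foldable (traverse_)

-- Hypothetical stand-in for modify: stores f s without evaluating it,
-- so repeated calls build a chain of (+) thunks.
lazyModify :: (s -> s) -> State s ()
lazyModify f = state (\s -> ((), f s))

-- Hypothetical stand-in for modify': evaluates the new state to WHNF
-- (for an Int, that means fully evaluated) before storing it.
strictModify :: (s -> s) -> State s ()
strictModify f = state (\s -> let s' = f s in s' `seq` ((), s'))

sumLazy, sumStrict :: [Int] -> Int
sumLazy xs = execState (traverse_ (\n -> lazyModify (n +)) xs) 0
sumStrict xs = execState (traverse_ (\n -> strictModify (n +)) xs) 0

main :: IO ()
main = do
  print (sumLazy [1 .. 100])    -- 5050, but via a thunk chain
  print (sumStrict [1 .. 100])  -- 5050, evaluated step by step
```

Both return the same answer. The difference is only in memory use and speed on large inputs.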

So the new optimised state function is:

import Control.Monad.State (execState, modify')
import Data.Foldable (traverse_)

sumState :: [Int] -> Int
sumState xs = execState (traverse_ f xs) 0
    where f n = modify' (n+)

which has a running time of around 350ms. This is a huge improvement, and it shocked me.

Why does the modified sumState perform better than sum1? Can sum1 be optimised to match or even beat sumState?

I also tried other implementations of sum:

  • using the built-in sum function (sum [1..x] :: Int), which gives me around 240ms
  • using the strict foldl', which gives the same result, around 240ms (with type [Int] -> Int)
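For reference, the two library-based variants from the list above might be written out like this. The names sumBuiltin and sumFoldl are only for illustration. foldl' forces the accumulator at every step, so it plays the same role that modify' plays in the State version.

```haskell
import Data.List (foldl')

-- Uses the Prelude's sum.
sumBuiltin :: [Int] -> Int
sumBuiltin = sum

-- Strict left fold: the accumulator is forced to WHNF at each step.
sumFoldl :: [Int] -> Int
sumFoldl = foldl' (+) 0

main :: IO ()
main = do
  print (sumBuiltin [1 .. 1000000])
  print (sumFoldl [1 .. 1000000])
```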

Does this mean it is better to use foldl' or the State monad to carry an accumulator instead of passing it as an argument to the function?

Thank you for your help.

EDIT:

Each function was in a separate file with its own main function and compiled with the -O2 flag.

import System.Environment (getArgs)

main :: IO ()
main = do
    x <- read . head <$> getArgs
    print $ sumState [1..x]  -- swap in the sum function under test, e.g. sum1 0

Runtime was measured with the time command on Linux.


Solution

  • To give a bit more explanation as to why traverse is slower: traverse f xs has type State Int [()], and that [()] (a list of unit values) is built up during the summation. This prevents further optimizations and would cause a memory leak if you were not using lazy state.

    Update: I think GHC should have been able to notice that that list of unit tuples is never used, so I opened a GHC issue.
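The extra list becomes visible if you swap execState, which discards the result, for runState, which returns both the result and the final state. In this sketch f mirrors the question's helper, and withTraverse / withTraverse_ are hypothetical names.

```haskell
import Control.Monad.State (State, modify', runState)
import Data.Foldable (traverse_)

f :: Int -> State Int ()
f n = modify' (n +)

-- traverse keeps one () per element in its result.
withTraverse :: [Int] -> ([()], Int)
withTraverse xs = runState (traverse f xs) 0

-- traverse_ discards the per-element results and returns a single ().
withTraverse_ :: [Int] -> ((), Int)
withTraverse_ xs = runState (traverse_ f xs) 0

main :: IO ()
main = do
  print (withTraverse [1, 2, 3])   -- ([(),(),()],6)
  print (withTraverse_ [1, 2, 3])  -- ((),6)
```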

    In both cases, to get the best performance we want to combine (or fuse) the summation with the enumeration [1..x] into a tight recursive loop that simply increments and adds until it reaches x. The resulting code would look something like this:

    -- s: running sum, x: next number to add, y: inclusive upper bound
    sumFromTo :: Int -> Int -> Int -> Int
    sumFromTo s x y
      | x > y     = s
      | otherwise = sumFromTo (s + x) (x + 1) y
    

    This avoids allocations for the list [1..x].

    The base library achieves this optimization using foldr/build fusion, also known as shortcut fusion. The sum, foldl' and traverse (for lists) functions are implemented using the foldr function, and [1..x] is implemented using the build function. The foldr and build functions have special rewrite rules so that they can be fused. Your custom sum1 function doesn't use foldr, so it can never be fused with [1..x] in this way.
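Here is a sketch of sum1 rewritten in terms of foldr. It uses the usual trick of folding into a function that takes the accumulator, which is essentially how GHC.List defines foldl'. It also includes a hypothetical producer, upTo, written with build from GHC.Exts, to show the shape of producer that GHC's fold/build rule (foldr k z (build g) = g k z) can eliminate. Whether GHC actually fuses a given pipeline depends on inlining, so treat this as an illustration rather than a guarantee. sum1Foldr and upTo are illustrative names.

```haskell
import GHC.Exts (build)

-- sum1 expressed with foldr: each element becomes a function that takes
-- the accumulator, adds the element (forced with $!), and passes the new
-- accumulator on. Applying the resulting chain to x0 gives the sum.
sum1Foldr :: Int -> [Int] -> Int
sum1Foldr x0 ys = foldr step id ys x0
  where
    step y k acc = k $! acc + y

-- A list producer written with build. If it meets a foldr, the
-- fold/build rule can replace cons/nil directly, so no cells are allocated.
upTo :: Int -> Int -> [Int]
upTo lo hi = build producer
  where
    producer :: (Int -> b -> b) -> b -> b
    producer cons nil = go lo
      where
        go i
          | i > hi    = nil
          | otherwise = i `cons` go (i + 1)

main :: IO ()
main = do
  print (sum1Foldr 0 [1 .. 1000000])
  print (sum1Foldr 0 (upTo 1 1000000))
```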