Tags: performance, haskell, vector, state-monad

Iterate State Monad and Collect Results in Sequence with Good Performance


I implemented the following function:

iterateState :: Int -> (a -> State s a) -> (a -> State s [a])
iterateState 0 f a = return []
iterateState n f a = do
    b <- f a
    xs <- iterateState (n - 1) f b
    return $ b : xs

My primary use case is a = Double. The function works, but it is very slow: it allocates 528 MB of heap space to produce a list of 1M Double values and spends most of its time in garbage collection.

I have experimented with implementations that work on the type s -> (a, s) directly as well as with various strictness annotations. I was able to reduce the heap allocation somewhat, but not even close to what one would expect from a reasonable implementation. I suspect that the resulting ([a], s) being a combination of something to be consumed lazily ([a]) and something whose WHNF forces the entire computation (s) makes optimization difficult for GHC.

Assuming that the iterative nature of lists would be unsuitable for this situation, I turned to the vector package. To my delight, it already contains

iterateNM :: (Monad m, Unbox a) => Int -> (a -> m a) -> a -> m (Vector a)

Unfortunately, this is only slightly faster than my list implementation, still allocating 328MB of heap space. I assumed that this is because it uses unstreamM, whose description reads

Load monadic stream bundle into a newly allocated vector. This function goes through a list, so prefer using unstream, unless you need to be in a monad.

Looking at its behavior for the list monad, it is understandable that there is no efficient implementation for general monads. Luckily, I only need the state monad, and I found another function that almost fits the signature of the state monad.

unfoldrExactN :: Unbox a => Int -> (b -> (a, b)) -> b -> Vector a

This function is blazingly fast and performs no excess heap allocation beyond the 8MB needed to hold the resulting unboxed vector of 1M Double values. Unfortunately, it does not return the final state at the end of the computation, so it cannot be wrapped in the State type.
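To make the mismatch concrete, here is a minimal sketch of how a state-passing step can be fed to unfoldrExactN. It assumes a plain step function of type a -> s -> (a, s) (what runState would give for a -> State s a), and the helper name iterateValues is hypothetical, not part of vector. The seed carries the pair of current value and state, and the final seed is simply dropped:

```haskell
import qualified Data.Vector.Unboxed as U

-- Hypothetical helper: iterate a pure state-passing step n times and
-- collect the produced values in an unboxed vector.
-- The seed threaded through unfoldrExactN is (current value, state).
-- unfoldrExactN returns only the vector, so the final seed (and with it
-- the final state) is not observable by the caller.
iterateValues :: U.Unbox a => Int -> (a -> s -> (a, s)) -> a -> s -> U.Vector a
iterateValues n f a0 s0 = U.unfoldrExactN n step (a0, s0)
  where
    step (a, s) = case f a s of
      (a', s') -> (a', (a', s'))
```

For example, iterateValues 4 (\a s -> (a + fromIntegral s, s + 1)) (0 :: Double) (0 :: Int) produces the four values 0, 1, 3, 6, but the counter value reached at the end is lost.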

I looked at the implementation of unfoldrExactN to see if I could adjust it to expose the final state at the end of the computation. Unfortunately, this seems to be difficult, as the stream constructed by

unfoldrExactN :: Monad m => Int -> (s -> (a, s)) -> s -> Stream m a

which is eventually expanded into a vector by unstream has already forgotten the state type s.
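Why the state is forgotten can be seen in a simplified, self-contained model of the stream type. The definitions below mirror the shape of Stream and Step in vector's fusion framework, but they are a stripped-down sketch: the primed name unfoldrExactN', the consumer toListS, and demo are hypothetical and only serve as illustration.

```haskell
{-# LANGUAGE ExistentialQuantification #-}
import Data.Functor.Identity (runIdentity)

-- One step of a stream: yield a value with a new seed, skip, or stop.
data Step s a = Yield a s | Skip s | Done

-- The seed type s is existentially quantified, so it does not appear in
-- the type Stream m a. A consumer can run the step function, but it can
-- never return the seed to its caller.
data Stream m a = forall s. Stream (s -> m (Step s a)) s

-- Sketch of an exact-length unfold: the user state and a countdown are
-- packed into the hidden seed.
unfoldrExactN' :: Monad m => Int -> (s -> (a, s)) -> s -> Stream m a
unfoldrExactN' n f s0 = Stream step (s0, n)
  where
    step (s, i)
      | i <= 0 = return Done
      | otherwise = case f s of
          (a, s') -> return (Yield a (s', i - 1))

-- A consumer only ever sees the yielded values.
toListS :: Monad m => Stream m a -> m [a]
toListS (Stream step t0) = go t0
  where
    go t = do
      r <- step t
      case r of
        Yield a t' -> (a :) <$> go t'
        Skip t'    -> go t'
        Done       -> return []

-- Collects the values 0, 1, 2; the final seed is unreachable.
demo :: [Int]
demo = runIdentity (toListS (unfoldrExactN' 3 (\s -> (s, s + 1)) 0))
```

Trying to write a consumer that returns the final seed fails to type check in this model, because the existential type variable would escape its scope.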

I imagine I could circumvent the entire Stream infrastructure and implement iterateState directly on mutable vectors in the ST monad (similarly to how unstream expands a stream into a vector). However, I would lose all the benefits of stream fusion, and I would be turning a computation that is easily expressed as a pure function into imperative low-level mush purely for performance reasons. This is particularly frustrating because the existing unfoldrExactN already computes all the values I want, yet I have no access to the final state.

Is there a better way?

Can this function be implemented in a purely functional way with reasonable performance and no excess heap allocation? Ideally, the solution would tie into the vector package and its stream fusion infrastructure.


Solution

  • The following program has 12MB max residency on my computer when compiled with optimizations:

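    import Data.Vector.Unboxed
    import Data.Vector.Unboxed.Mutable
    
    -- Run the step function n times, writing each value into a mutable unboxed
    -- vector, and return the final state alongside the frozen vector.
    iterateNState :: Unbox a => Int -> (a -> s -> (s, a)) -> (a -> s -> (s, Vector a))
    iterateNState n f a0 s0 = createT (unsafeNew n >>= go 0 a0 s0) where
        -- i: next index to fill, a: value to store there, s: current state
        go i a s arr
            | i >= n = pure (s, arr)
            | otherwise = do
                unsafeWrite arr i a
                -- The case forces the result pair before the loop continues.
                case f a s of
                    (s', a') -> go (i+1) a' s' arr
    
    main :: IO ()
    main = id
        . print
        . Data.Vector.Unboxed.sum
        . snd
        $ iterateNState 1000000 (\a s -> (s+1, a+s :: Int)) 0 0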
    

    (It continues to have a nice low residency even when the final two 0s are read from input dynamically.)
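To connect this back to the signature in the question, the loop can be wrapped in the State type. The following is a minimal sketch, assuming the strict State from transformers; the name iterateStateV is hypothetical, and iterateNState is repeated with qualified imports so that the block stands alone. Note that, as in the program above, the first element stored is the initial value a0 rather than f applied to it. Whether GHC removes the overhead of the wrapper has not been measured here.

```haskell
import Control.Monad.Trans.State.Strict (State, runState, state)
import Data.Tuple (swap)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M

-- Same loop as in the answer, with qualified imports.
iterateNState :: U.Unbox a => Int -> (a -> s -> (s, a)) -> a -> s -> (s, U.Vector a)
iterateNState n f a0 s0 = U.createT (M.unsafeNew n >>= go 0 a0 s0)
  where
    go i a s arr
      | i >= n = pure (s, arr)
      | otherwise = do
          M.unsafeWrite arr i a
          case f a s of
            (s', a') -> go (i + 1) a' s' arr

-- Hypothetical wrapper: adapt a State action to the (s, a) ordering used by
-- iterateNState, then swap the result back to State's (value, state) order.
iterateStateV :: U.Unbox a => Int -> (a -> State s a) -> a -> State s (U.Vector a)
iterateStateV n f a0 = state $ \s0 ->
  swap (iterateNState n (\a s -> swap (runState (f a) s)) a0 s0)
```

For instance, runState (iterateStateV 3 (\a -> state (\s -> (a + s, s + 1))) 0) 0 evaluates to the vector [0,0,1] paired with the final state 3.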