Tags: haskell, memoization, mutual-recursion

How to speed up (or memoize) a series of mutually recursive functions


I have a program that produces a series of functions f and g, defined like this:

step (f,g) = (newF f g, newG f g)

newF f g x = r (f x) (g x)
newG f g x = s (f x) (g x)

foo = iterate step (f0,g0)

Here r and s are some uninteresting functions of f x and g x. I naively hoped that, because foo is a list, calling the nth f would not recompute the (n-1)th f if it had already been computed (as would be the case if f and g were ordinary values rather than functions). Is there any way to memoize this without ripping the whole program apart (e.g. evaluating f0 and g0 on all relevant arguments and then working upward)?
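To make the problem concrete, here is a self-contained sketch of the setup above. The definitions of r, s, f0 and g0 are hypothetical stand-ins (simple Integer arithmetic) chosen only so the code runs; step and foo keep the question's structure. Because newF and newG each call both functions of the previous level, and nothing is shared between those calls, evaluating level n at a single point makes 2^n calls to f0 and g0 in total.

```haskell
-- A runnable version of the question's setup.
-- r, s, f0 and g0 are hypothetical stand-ins; only their shapes matter.

r, s :: Integer -> Integer -> Integer
r a b = a + b
s a b = a - b

f0, g0 :: Integer -> Integer
f0 x = x
g0 _ = 1

newF, newG :: (Integer -> Integer) -> (Integer -> Integer) -> Integer -> Integer
newF f g x = r (f x) (g x)
newG f g x = s (f x) (g x)

step :: (Integer -> Integer, Integer -> Integer)
     -> (Integer -> Integer, Integer -> Integer)
step (f, g) = (newF f g, newG f g)

-- The list shares each pair of *functions*, not their results:
-- fst (foo !! n) x makes 2 calls at level n-1, 4 at level n-2, and so on.
foo :: [(Integer -> Integer, Integer -> Integer)]
foo = iterate step (f0, g0)
```

For example, fst (foo !! 4) 3 evaluates to 12, but a single call at level 30 already makes over a billion calls to f0 and g0.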


Solution

  • You may find Data.MemoCombinators useful (in the data-memocombinators package).

    You don't say what argument types your f and g take. If they both take integral values, you would use it like this:

    import qualified Data.MemoCombinators as Memo
    
    foo = iterate step (Memo.integral f0, Memo.integral g0)
    

    If required, you could memoise the output of each step as well, so that each level's results are cached and reused by the next level:

    step (f,g) = (Memo.integral (newF f g), Memo.integral (newG f g))
    

    I hope you don't see this as ripping the whole program apart.
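To illustrate why memoising each step matters, here is a dependency-free sketch of the same idea. It swaps Memo.integral for a hypothetical memoNat that caches results in a lazy list, so it only accepts non-negative Int arguments and has linear-time lookup; Memo.integral from data-memocombinators is the more general tool. The r, s, f0 and g0 definitions are again hypothetical stand-ins. Memoising only f0 and g0 would still leave 2^n table lookups at level n; memoising every step means each level is evaluated at most once per argument, and level n+1 looks up level n instead of recomputing it.

```haskell
-- memoNat is a hypothetical stand-in for Memo.integral: it caches results
-- in a lazy list, so it only works for non-negative Int arguments.
memoNat :: (Int -> a) -> Int -> a
memoNat f = \n -> table !! n
  where table = map f [0 ..]   -- built once per memoNat application, then shared

-- Hypothetical stand-ins for the question's r, s, f0 and g0.
r, s :: Int -> Int -> Int
r a b = a + b
s a b = a - b

f0, g0 :: Int -> Int
f0 x = x
g0 _ = 1

newF, newG :: (Int -> Int) -> (Int -> Int) -> Int -> Int
newF f g x = r (f x) (g x)
newG f g x = s (f x) (g x)

-- Memoise the output of each step, as suggested above.
step :: (Int -> Int, Int -> Int) -> (Int -> Int, Int -> Int)
step (f, g) = (memoNat (newF f g), memoNat (newG f g))

foo :: [(Int -> Int, Int -> Int)]
foo = iterate step (memoNat f0, memoNat g0)
```

With the un-memoised step, fst (foo !! 40) 3 would make 2^40 base calls; with this version it returns 3145728 quickly.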


    In reply to your comment:

    This is the best I can come up with. It's untested, but it should be along the right lines.

    I worry that converting between Double and Rational is needlessly inefficient, so this might not ultimately be of any practical use to you. (If there were a Bits instance for Double, we could use Memo.bits instead.)

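    {-# LANGUAGE RankNTypes #-}
    -- RankNTypes is needed because Memo.Memo is a rank-2 type synonym.
    import Control.Arrow ((&&&))
    import Data.Ratio (numerator, denominator, (%))
    import qualified Data.MemoCombinators as Memo
    
    -- V is your 3-component vector type (constructor V x y z).
    -- Memoise a function on V by memoising each component with m.
    memoV :: Memo.Memo a -> Memo.Memo (V a)
    memoV m f = \(V x y z) -> table x y z
      where g x y z = f (V x y z)
            table = Memo.memo3 m m m g
    
    -- Memoise on any RealFrac (e.g. Double) via its exact Rational value,
    -- keyed on the (numerator, denominator) pair of Integers.
    memoRealFrac :: RealFrac a => Memo.Memo a
    memoRealFrac = Memo.wrap (fromRational . uncurry (%))
                             ((numerator &&& denominator) . toRational)
                             (Memo.pair Memo.integral Memo.integral)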
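The memoRealFrac conversion relies on toRational being exact for finite Doubles, so each Double maps to a unique (numerator, denominator) key and fromRational maps it back unchanged. The hypothetical helpers key and unkey below isolate that round trip using only base. They also show where the efficiency worry comes from: even a short literal like 0.1 becomes a pair of 16- and 17-digit integers.

```haskell
import Control.Arrow ((&&&))
import Data.Ratio (numerator, denominator, (%))

-- Hypothetical helpers: the two conversions handed to Memo.wrap in memoRealFrac.
-- key turns a Double into the exact (numerator, denominator) of its Rational value.
key :: Double -> (Integer, Integer)
key = (numerator &&& denominator) . toRational

-- unkey rebuilds the Double; for finite inputs, unkey (key d) == d.
unkey :: (Integer, Integer) -> Double
unkey = fromRational . uncurry (%)
```

In GHCi, key 0.1 gives (3602879701896397,36028797018963968), the exact binary value of the Double literal 0.1.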
    

    A different approach.

    You have

    step :: (V Double -> V Double, V Double -> V Double)
         -> (V Double -> V Double, V Double -> V Double)
    

    How about you change that to

    step :: (V Double -> (V Double, V Double))
         -> (V Double -> (V Double, V Double))
    step h x = (r fx gx, s fx gx)
      where (fx, gx) = h x
    

    And also change

    -- (&&&) is from Control.Arrow; each h returns (f x, g x) for one level.
    foo = [ (fst . h, snd . h) | h <- iterate step (f0 &&& g0) ]
    

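    Because fx and gx are now computed once and shared by r and s, each level calls the previous level once per argument instead of twice. Evaluating the nth level at a point therefore takes about n steps rather than 2^n, which should give a substantial speed-up.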
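Here is a self-contained sketch of this paired version, using hypothetical Int-valued stand-ins for r, s, f0 and g0 in place of the V Double functions. step threads a single function that returns both results, so each level evaluates the previous one once per argument.

```haskell
import Control.Arrow ((&&&))

-- Hypothetical stand-ins for the question's r, s, f0 and g0.
r, s :: Int -> Int -> Int
r a b = a + b
s a b = a - b

f0, g0 :: Int -> Int
f0 x = x
g0 _ = 1

-- One function per level returning (f x, g x) together, so fx and gx
-- are computed once and shared by r and s.
step :: (Int -> (Int, Int)) -> (Int -> (Int, Int))
step h x = (r fx gx, s fx gx)
  where (fx, gx) = h x

-- Split each level back into separate f and g functions.
foo :: [(Int -> Int, Int -> Int)]
foo = [ (fst . h, snd . h) | h <- iterate step (f0 &&& g0) ]
```

Here fst (foo !! 40) 3 runs 40 steps and returns 3145728. One design note: calling both the f and the g of a level at the same x runs the chain once for each, so if you always need both, use the paired function from iterate directly.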