
Memoization in Haskell?


Any pointers on how to compute the following function efficiently in Haskell, for large numbers (n > 10^8)?

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

I've seen examples of memoization in Haskell for Fibonacci numbers, which involved computing (lazily) all the Fibonacci numbers up to the required n. But in this case, for a given n, we only need to compute very few intermediate results.
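A small sketch makes that last point concrete. The helper name reachable is hypothetical, and "/" is read as integer division (as the answer below does with div). It collects every distinct argument the recurrence can reach from a given n:

```haskell
import qualified Data.Set as Set

-- Hypothetical helper: every distinct argument that
-- f(n) = max(n, f(n/2) + f(n/3) + f(n/4)) can reach from a starting n,
-- reading "/" as integer division.
reachable :: Integer -> Set.Set Integer
reachable start = go Set.empty [start]
  where
    -- Depth-first walk with an explicit stack of pending arguments.
    go seen [] = seen
    go seen (m : rest)
      | m `Set.member` seen = go seen rest      -- already recorded
      | m == 0 = go (Set.insert 0 seen) rest    -- base case: no further calls
      | otherwise =
          go (Set.insert m seen) ([m `div` 2, m `div` 3, m `div` 4] ++ rest)
```

Since nested floor divisions compose ((n div a) div b equals n div (a*b) for positive a and b), every reachable value has the form n div (2^a * 3^b), so the set grows only roughly like (log n)^2. For n = 123 it has 16 elements.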

Thanks


Solution

  • We can do this very efficiently by making a structure that we can index in sub-linear time.

    But first,

    {-# LANGUAGE BangPatterns #-}
    
    import Data.Function (fix)
    

    Let's define f, but make it use 'open recursion' rather than call itself directly.

    f :: (Int -> Int) -> Int -> Int
    f mf 0 = 0
    f mf n = max n $ mf (n `div` 2) +
                     mf (n `div` 3) +
                     mf (n `div` 4)
    

    You can get an unmemoized f by using fix f.

    This will let you test that f does what you mean for small values of n by calling, for example, fix f 123, which should give 144.
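    To see what memoization buys, here is a minimal sketch that also counts how many times the unmemoized fix f evaluates f. The names slowF and callCount are hypothetical, introduced only for this illustration. For n = 123, callCount reports 493 evaluations, although only 16 distinct arguments (123, 61, 41, 30, 20, 15, 13, 10, 7, 6, 5, 4, 3, 2, 1, 0) ever occur; that repeated work is what a memo table removes.

```haskell
import Data.Function (fix)

-- Open-recursive step, as in the answer: mf stands for "the recursive call".
f :: (Int -> Int) -> Int -> Int
f _ 0 = 0
f mf n = max n $ mf (n `div` 2) + mf (n `div` 3) + mf (n `div` 4)

-- Unmemoized version: fix feeds the function back into itself.
slowF :: Int -> Int
slowF = fix f

-- Hypothetical helper: how many times slowF n evaluates f (the top call included).
callCount :: Int -> Int
callCount 0 = 1
callCount n = 1 + callCount (n `div` 2) + callCount (n `div` 3) + callCount (n `div` 4)
```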

    We could memoize this by defining:

    f_list :: [Int]
    f_list = map (f faster_f) [0..]
    
    faster_f :: Int -> Int
    faster_f n = f_list !! n
    

    That performs passably well, and replaces the repeated recomputation of the unmemoized version with a table that memoizes the intermediate results.

    But it still takes linear time to walk down the list to the memoized answer for mf. This means that results like:

    *Main Data.List> faster_f 123801
    248604
    

    are tolerable, but the result doesn't scale much better than that. We can do better!

    First, let's define an infinite tree:

    data Tree a = Tree (Tree a) a (Tree a)
    instance Functor Tree where
        fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
    

    And then we'll define a way to index into it, so we can find a node with index n in O(log n) time instead:

    index :: Tree a -> Int -> a
    index (Tree _ m _) 0 = m
    index (Tree l _ r) n = case (n - 1) `divMod` 2 of
        (q,0) -> index l q
        (q,1) -> index r q
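    A minimal sketch of why the lookup is logarithmic: the hypothetical helper path below replays the arithmetic of index and records which subtree it descends into at each level. Node n sits at depth floor(log2(n + 1)), so, for example, node 10^9 is reached in 29 steps.

```haskell
-- Hypothetical helper type: which child 'index' descends into.
data Dir = L | R deriving (Eq, Show)

-- Same arithmetic as 'index': node 0 is the root; node n > 0 is node q of the
-- left subtree if n - 1 = 2q, or node q of the right subtree if n - 1 = 2q + 1.
path :: Int -> [Dir]
path 0 = []
path n = case (n - 1) `divMod` 2 of
  (q, 0) -> L : path q
  (q, _) -> R : path q
```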
    

    ... and a tree full of natural numbers is convenient, so we don't have to fiddle around with those indices:

    nats :: Tree Int
    nats = go 0 1
        where
            go !n !s = Tree (go l s') n (go r s')
                where
                    l = n + s
                    r = l + s
                    s' = s * 2
    

    Since we can index, you can just convert a tree into a list:

    toList :: Tree a -> [a]
    toList as = map (index as) [0..]
    

    You can check the work so far by verifying that toList nats gives you [0..].
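    Collected into one self-contained sketch (the answer's definitions, with comments added), the pieces so far compile on their own. Since the list is infinite, a check can only compare a finite prefix, such as take 1000 (toList nats) against [0..999].

```haskell
{-# LANGUAGE BangPatterns #-}

-- Infinite binary tree; the root holds node 0.
data Tree a = Tree (Tree a) a (Tree a)

-- O(log n) lookup of node n.
index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

-- Tree whose node n holds n. A node at depth d is built with s = 2^d,
-- and its children are labelled n + s and n + 2s.
nats :: Tree Int
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l = n + s
        r = l + s
        s' = s * 2

-- Read the tree back in index order.
toList :: Tree a -> [a]
toList as = map (index as) [0 ..]
```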

    Now,

    f_tree :: Tree Int
    f_tree = fmap (f fastest_f) nats
    
    fastest_f :: Int -> Int
    fastest_f = index f_tree
    

    works just like the list version above, but instead of taking linear time to find each node, it can chase it down in logarithmic time.
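    The same knot-tying works for any open-recursive function on non-negative Ints, so it can be packaged as a reusable combinator. This is a sketch under that assumption: memoFix is a name chosen here for illustration, and the Tree, index and nats definitions are repeated so the block compiles on its own.

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

nats :: Tree Int
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l = n + s
        r = l + s
        s' = s * 2

-- Illustrative combinator: tie the knot through a lazily built tree, so every
-- recursive call made by 'step' is answered from the shared table.
memoFix :: ((Int -> a) -> Int -> a) -> Int -> a
memoFix step = memoized
  where
    table = fmap (step memoized) nats
    memoized = index table

-- Open-recursive step from the answer.
f :: (Int -> Int) -> Int -> Int
f _ 0 = 0
f mf n = max n $ mf (n `div` 2) + mf (n `div` 3) + mf (n `div` 4)

fastest_f :: Int -> Int
fastest_f = memoFix f
```

    Because fastest_f is a top-level constant, the table inside it is built lazily once and shared by every later call. Negative arguments are not supported, since index never terminates for them.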

    The result is considerably faster:

    *Main> fastest_f 12380192300
    67652175206
    
    *Main> fastest_f 12793129379123
    120695231674999
    

    In fact it is so much faster that you can go through and replace Int with Integer above (calling the result fastest_f') and get ridiculously large answers almost instantaneously:

    *Main> fastest_f' 1230891823091823018203123
    93721573993600178112200489
    
    *Main> fastest_f' 12308918230918230182031231231293810923
    11097012733777002208302545289166620866358
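    A sketch of that Integer variant, following the recipe of replacing Int with Integer throughout (the index argument, the labels in nats, and f itself). The primed names are chosen to match the session above; f_tree' is an assumed name for the Integer counterpart of f_tree.

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

-- Indices are now unbounded Integers; a lookup still takes O(log n) steps.
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l = n + s
        r = l + s
        s' = s * 2

f :: (Integer -> Integer) -> Integer -> Integer
f _ 0 = 0
f mf n = max n $ mf (n `div` 2) + mf (n `div` 3) + mf (n `div` 4)

-- Memo table and memoized function, tied together as in the Int version.
f_tree' :: Tree Integer
f_tree' = fmap (f fastest_f') nats

fastest_f' :: Integer -> Integer
fastest_f' = index f_tree'
```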
    

    For an out-of-the-box library that implements tree-based memoization, use MemoTrie:

    $ stack repl --package MemoTrie
    
    Prelude> import Data.MemoTrie
    Prelude Data.MemoTrie> :set -XLambdaCase
    Prelude Data.MemoTrie> :{
    Prelude Data.MemoTrie| fastest_f' :: Integer -> Integer
    Prelude Data.MemoTrie| fastest_f' = memo $ \case
    Prelude Data.MemoTrie|   0 -> 0
    Prelude Data.MemoTrie|   n -> max n (fastest_f'(n `div` 2) + fastest_f'(n `div` 3) + fastest_f'(n `div` 4))
    Prelude Data.MemoTrie| :}
    Prelude Data.MemoTrie> fastest_f' 12308918230918230182031231231293810923
    11097012733777002208302545289166620866358