Memoization of Ackermann function


I would like to compute A(3, 20), a value of the Ackermann function (see Wikipedia), which should be 2^23 - 3 = 8388605, using Data.MemoCombinators. My code is:

{-# LANGUAGE BangPatterns #-}
import qualified Data.MemoCombinators as Memo

ack = Memo.memo2 Memo.integral Memo.integral ack'
    where
        ack' 0 !n = n+1
        ack' !m 0 = ack (m-1) 1
        ack' !m !n = ack (m-1) $! (ack m (n-1))

main = print $ ack 3 20

But it ends with a stack overflow error ;-) Can it be tuned, or is the computation chain really so long that even memoization cannot help?


Solution

  • One of the points of the Ackermann function is that computing it recursively leads to a very deep recursion.

    Without memoisation, the recursion depth is about equal to the result (a few levels more or less, depending on how you count). Unfortunately, memoisation doesn't buy you much if you fill the memo table in the order of the call tree.
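
    The depth claim can be checked for small arguments. The following is a minimal sketch, not part of the original code: ackDepth is a hypothetical helper that evaluates the plain recursion with an explicit stack and records the deepest stack it reaches. Counting only the calls that are still waiting for a result is one of several possible ways to count.

```haskell
-- Hypothetical helper: evaluate the plain (unmemoised) Ackermann recursion
-- with an explicit stack and return (result, deepest stack reached).
-- The head of the stack is the first argument of the call being evaluated;
-- the tail holds the first arguments of the calls waiting for its result.
ackDepth :: Integer -> Integer -> (Integer, Int)
ackDepth m0 n0 = go [m0] 1 1 n0
  where
    go :: [Integer] -> Int -> Int -> Integer -> (Integer, Int)
    -- Nothing pending: n is the final result.
    go [] _ best n = (n, best)
    -- ack 0 n = n + 1: pop one pending call.
    go (0 : ms) d best n = go ms (d - 1) best (n + 1)
    -- ack m 0 = ack (m-1) 1: replace in place, depth unchanged.
    go (m : ms) d best 0 = go ((m - 1) : ms) d best 1
    -- ack m n = ack (m-1) (ack m (n-1)): the outer call waits on the stack.
    go (m : ms) d best n =
      let d' = d + 1
          best' = max best d'
      in best' `seq` go (m : (m - 1) : ms) d' best' (n - 1)
```

    Under this way of counting, ackDepth 3 2 gives (29, 28): the result is 29 and the deepest stack holds 28 calls.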

    Let's follow the computation of ack 3 2:

    ack 3 2
    ack 2 $ ack 3 1
    ack 2 $ ack 2 $ ack 3 0
    ack 2 $ ack 2 $ ack 2 1
    ack 2 $ ack 2 $ ack 1 $ ack 2 0
    ack 2 $ ack 2 $ ack 1 $ ack 1 1
    ack 2 $ ack 2 $ ack 1 $ ack 0 $ ack 1 0
    ack 2 $ ack 2 $ ack 1 $ ack 0 $ ack 0 1    -- here's the first value we can compute and put in the map
    ack 2 $ ack 2 $ ack 1 $ ack 0 2            -- next three, (0,2) -> 3, (1,1)->3 and (2,0)->3
    ack 2 $ ack 2 $ ack 1 3                    -- need to unfold that
    ack 2 $ ack 2 $ ack 0 $ ack 1 2
    ack 2 $ ack 2 $ ack 0 $ ack 0 $ ack 1 1    -- we know that, it's 3
    ack 2 $ ack 2 $ ack 0 $ ack 0 3            -- okay, easy (0,3)->4, (1,2)->4
    ack 2 $ ack 2 $ ack 0 4                    -- (0,4)->5, (1,3)->5, (2,1)->5
    ack 2 $ ack 2 5                            -- unfold
    ack 2 $ ack 1 $ ack 2 4
    ack 2 $ ack 1 $ ack 1 $ ack 2 3
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 2 2
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 1
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 0  -- we know that one, 3
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 3          -- that one too, it's 5
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 1 5                  -- but not that
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 1 4
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 3  -- look up
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 5          -- easy (0,5)->6
    ack 2 $ ack 1 $ ack 1 $ ack 1 $ ack 0 6                  -- now (1,5)->7 is known too, and (2,2)->7
    ack 2 $ ack 1 $ ack 1 $ ack 1 7
    ack 2 $ ack 1 $ ack 1 $ ack 0 $ ack 1 6
    ack 2 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 5
    ack 2 $ ack 1 $ ack 1 $ ack 0 $ ack 0 7                  -- here (1,6)->8 becomes known
    ack 2 $ ack 1 $ ack 1 $ ack 0 8                          -- and here (1,7)->9, (2,3)->9
    ack 2 $ ack 1 $ ack 1 9
    ack 2 $ ack 1 $ ack 0 $ ack 1 8
    ack 2 $ ack 1 $ ack 0 $ ack 0 $ ack 1 7                  -- known
    ack 2 $ ack 1 $ ack 0 $ ack 0 9                          -- here we can add (1,8)->10
    ack 2 $ ack 1 $ ack 0 10                                 -- and (1,9)->11, (2,4)->11
    ack 2 $ ack 1 11
    ack 2 $ ack 0 $ ack 1 10
    ack 2 $ ack 0 $ ack 0 $ ack 1 9                          -- known
    ack 2 $ ack 0 $ ack 0 11                                 -- (1,10)->12
    ack 2 $ ack 0 12                                         -- (1,11)->13, (2,5)->13
    ack 2 13
    ack 1 $ ack 2 12
    ack 1 $ ack 1 $ ack 2 11
    ack 1 $ ack 1 $ ack 1 $ ack 2 10
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 9
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 8
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 7
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 6
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 2 5 -- uff
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 13
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 1 12
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 11 -- uff
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 13         -- (1,12)->14
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 14          -- (1,13)->15, (2,6)->15
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 15
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 1 14
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 13
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 15          -- (1,14)->16
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 16                  -- (1,15)->17, (2,7)->17
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 17
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 1 16
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 15
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 17                  -- (1,16)->18
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 18                          -- (1,17)->19, (2,8)->19
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 1 19
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 1 18
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 17
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 19                          -- (1,18)->20
    ack 1 $ ack 1 $ ack 1 $ ack 1 $ ack 0 20                                  -- (1,19)->21, (2,9)->21
    ack 1 $ ack 1 $ ack 1 $ ack 1 21
    ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 1 20
    ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 19                          -- known
    ack 1 $ ack 1 $ ack 1 $ ack 0 $ ack 0 21                                  -- (1,20)->22
    ack 1 $ ack 1 $ ack 1 $ ack 0 22                                          -- (1,21)->23, (2,10)->23
    ack 1 $ ack 1 $ ack 1 23
    ack 1 $ ack 1 $ ack 0 $ ack 1 22
    ack 1 $ ack 1 $ ack 0 $ ack 0 $ ack 1 21                                  -- known
    ack 1 $ ack 1 $ ack 0 $ ack 0 23                                          -- (1,22)->24
    ack 1 $ ack 1 $ ack 0 24                                                  -- (1,23)->25, (2,11)->25
    ack 1 $ ack 1 25
    ack 1 $ ack 0 $ ack 1 24
    ack 1 $ ack 0 $ ack 0 $ ack 1 23                                          -- known
    ack 1 $ ack 0 $ ack 0 25                                                  -- (1,24)->26
    ack 1 $ ack 0 26                                                          -- (1,25)->27, (2,12)-> 27
    ack 1 27
    ack 0 $ ack 1 26
    ack 0 $ ack 0 $ ack 1 25
    ack 0 $ ack 0 27
    ack 0 28
    29
    

    So when you need to calculate a new (not-yet-known) ack 1 n, you need to compute two new ack 0 n, and when you need a new ack 2 n, you need two new ack 1 n and hence four new ack 0 n. None of that is too dramatic.

    But when you need a new ack 3 n, you need ack 3 (n-1) - ack 3 (n-2) new ack 2 k. All told, after you have computed ack 3 k, you need to compute 2^(k+2) new values of ack 2 n, and because of the calling structure these are nested calls, so you get a stack of 2^(k+2) nested thunks.
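
    The count can be sanity-checked with the closed form ack 3 n = 2^(n+3) - 3, which agrees with the 2^23 - 3 quoted in the question and with the 29 reached in the trace of ack 3 2. A minimal sketch; ack3 and newAck2 are names introduced here for illustration only:

```haskell
-- Closed form for the third row: ack 3 n = 2^(n+3) - 3.
ack3 :: Integer -> Integer
ack3 n = 2 ^ (n + 3) - 3

-- Number of ack 2 k values first needed when ack 3 n is computed
-- (for n >= 2): those with ack 3 (n-2) < k <= ack 3 (n-1).
newAck2 :: Integer -> Integer
newAck2 n = ack3 (n - 1) - ack3 (n - 2)
```

    For example, newAck2 20 is 2^21 = 2097152, which matches 2^(k+2) with k = 19.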

    To avoid that nesting, you need to restructure the computation, e.g. by forcing the newly needed ack (m-1) k in increasing order of k (foldl1' is from Data.List),

        ack' m 1 = ack (m-1) $! ack (m-1) 1
        ack' m n = foldl1' max [ack (m-1) k | k <- [ack m (n-2) .. ack m (n-1)]]
    

    which allows the computation to run (slowly) with a small stack. It still needs a terrible lot of heap, though, so a tailor-made memoisation strategy seems called for.
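
    Why the fold gives the right answer: ack (m-1) is strictly increasing in its second argument, so the maximum over k in [ack m (n-2) .. ack m (n-1)] is ack (m-1) (ack m (n-1)), which is ack m n. Here is a minimal sketch that checks this on tiny inputs without any memo table. The names ackNaive, ackFold and agree are made up for this illustration, and without memoisation ackFold is only usable for very small arguments.

```haskell
import Data.List (foldl1')

-- Direct definition, fine for tiny arguments only.
ackNaive :: Integer -> Integer -> Integer
ackNaive 0 n = n + 1
ackNaive m 0 = ackNaive (m - 1) 1
ackNaive m n = ackNaive (m - 1) (ackNaive m (n - 1))

-- Restructured form: for n >= 2, evaluate ack (m-1) k for every k from
-- ack m (n-2) to ack m (n-1) in increasing order and keep the maximum.
ackFold :: Integer -> Integer -> Integer
ackFold 0 n = n + 1
ackFold m 0 = ackFold (m - 1) 1
ackFold m 1 = ackFold (m - 1) $! ackFold (m - 1) 1
ackFold m n =
  foldl1' max [ackFold (m - 1) k | k <- [ackFold m (n - 2) .. ackFold m (n - 1)]]

-- Compare both definitions on a small grid of arguments.
agree :: Bool
agree = and [ackNaive m n == ackFold m n | m <- [0 .. 2], n <- [0 .. 4]]
```

    On the grid checked here, agree evaluates to True. The memoised version differs only in that each ack (m-1) k is looked up or stored instead of recomputed.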

    Storing only ack m n for m >= 2, and evaluating ack 1 n as if it were memoised, reduces the necessary memory far enough that computing ack 3 20 finishes using less than 1 GB of heap (using Int instead of Integer makes it run about twice as fast):

    {-# LANGUAGE BangPatterns #-}
    module Main (main) where
    
    import qualified Data.Map as M
    import Control.Monad.State.Strict
    import Control.Monad
    
    type Table = M.Map (Integer,Integer) Integer
    
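    -- Only entries with m >= 2 and n >= 2 go into the table; rows 0 and 1
    -- are computed directly.
    ack :: Integer -> Integer -> State Table Integer
    ack 0 n = return (n+1)
    ack 1 n = return (n+2)          -- closed form, stands in for a memoised row 1
    ack m 0 = ack (m-1) 1
    ack m 1 = do
        !n <- ack (m-1) 1
        ack (m-1) n
    ack m n = do
        mb <- gets (M.lookup (m,n))
        case mb of
          Just v -> return v
          Nothing -> do
              !s <- ack m (n-2)
              !t <- ack m (n-1)
              -- Visit ack (m-1) b for b = s .. t in increasing order and keep
              -- the running maximum, which ends up as ack (m-1) t = ack m n.
              let foo a b = do
                    c <- ack (m-1) b
                    let d = max a c
                    return $! d
              !v <- foldM foo 0 [s .. t]
              mp <- get
              put $! M.insert (m,n) v mp
              return v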
    
    main :: IO ()
    main = print $ evalState (ack 3 20) M.empty