
Haskell: memoization of recursion


If I have the following function:

go xxs t i
  | t == 0         = 1
  | t < 0          = 0
  | i < 0          = 0
  | t < (xxs !! i) = go xxs t (i-1)
  | otherwise      = go xxs (t - (xxs !! i)) (i-1) + go xxs t (i-1)

What is the best way to memoize the results? I can't seem to get my head around how to store a dynamic set of tuples and update and return the value at the same time.

The equivalent of what I am trying to do in Python would be:

def go(xxs, t , i, m):
  k = (t,i)
  if  k in m:      # check if value for this pair is already in dictionary 
      return m[k]
  if t == 0:
      return 1
  elif t < 0:
      return 0
  elif i < 0:
      return 0
  elif t < xxs[i]:
      val = go(xxs, t, i-1,m)  
  else:
      val = go(xxs, t - xxs[i], i-1, m) + go(xxs, t, i-1, m)
  m[k] = val  # store the new value in dictionary before returning it
  return val

EDIT: I think this is somewhat different from this answer. The function in question there has a linear progression, so you can index the results with a list [1..]. In this case, my keys (t,i) are not necessarily in order or incremental. For example, I could end up with a set of keys like these:

[(9,1),(8,2),(7,4),(6,4),(5,5),(4,6),(3,6),(2,7),(1,8),(0,10)]


Solution

  • Is there no easier way to roll your own [memoization]?

    Easier than what? A State monad is really easy, and if you are used to thinking imperatively it should also be intuitive.

    The full, inlined version, which uses a vector instead of a list, is:

    {-# LANGUAGE MultiWayIf #-}
    import Control.Monad.Trans.State (State, evalState, get)
    import qualified Control.Monad.Trans.State as S
    import Data.Vector (Vector)
    import qualified Data.Vector as V
    import Data.Map.Strict (Map)
    import qualified Data.Map.Strict as M
    
    goGood :: [Int] -> Int -> Int -> Int
    goGood xs t0 i0 =
        let v = V.fromList xs
        in evalState (explicitMemo v t0 i0) mempty
     where
     explicitMemo :: Vector Int -> Int -> Int -> State (Map (Int,Int) Int) Int
     explicitMemo v t i = do
        m <- M.lookup (t,i) <$> get   -- already computed?
        case m of
            Nothing ->
             do res <- if | t == 0          -> pure 1
                          | t < 0           -> pure 0
                          | i < 0           -> pure 0
                          | t < (v V.! i)   -> explicitMemo v t (i-1)
                          | otherwise       -> (+) <$> explicitMemo v (t - (v V.! i)) (i-1)
                                                   <*> explicitMemo v t (i-1)
                S.modify (M.insert (t,i) res)   -- cache it before returning
                pure res
            Just r  -> pure r
    

    That is, we look up in a map whether we've already computed the result. If so, we return it. If not, we compute the result and store it in the map before returning it.

    We can clean this up a lot with just a couple of helper functions:

    prettyMemo :: Vector Int -> Int -> Int -> State (Map (Int,Int) Int) Int
    prettyMemo v t i = cachedReturn =<< cachedEval (
                if | t == 0          -> pure 1
                   | t < 0           -> pure 0
                   | i < 0           -> pure 0
                   | t < (v V.! i)   -> prettyMemo v t (i-1)
                   | otherwise       ->
                       (+) <$> prettyMemo v (t - (v V.! i)) (i-1)
                           <*> prettyMemo v t (i-1)
                )
     where
     key = (t,i)
     -- Store the value in the cache and return it
     cachedReturn res = S.modify (M.insert key res) >> pure res
    
     -- Use cached value or run the operation
     cachedEval oper = maybe oper pure =<< (M.lookup key <$> get)
    

    Now the map lookup and map update live in small helper functions (simple, at least to the experienced Haskell developer) that wrap the entire computation. One small difference: we now update the map even when the result was already cached, at some minor computational cost.
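    The two helpers can be folded into one reusable combinator. The following is a minimal sketch, assuming only the transformers and containers packages that ship with GHC; memoM, countM and countSubsets are hypothetical names, and a plain list indexed with (!!) stands in for the Vector so there are no other dependencies. Unlike prettyMemo, it only writes to the map on a cache miss.

```haskell
{-# LANGUAGE MultiWayIf #-}
import Control.Monad.Trans.State (State, evalState, gets, modify)
import qualified Data.Map.Strict as M

-- Hypothetical combinator (not from any library): return the cached value
-- for `key` if present; otherwise run `action`, cache its result, return it.
memoM :: Ord k => k -> State (M.Map k v) v -> State (M.Map k v) v
memoM key action = do
  cached <- gets (M.lookup key)
  case cached of
    Just v  -> pure v            -- hit: skip the computation entirely
    Nothing -> do
      v <- action                -- miss: compute the value...
      modify (M.insert key v)    -- ...and remember it for later calls
      pure v

-- The question's recurrence, with every (t, i) call routed through memoM.
-- A list with (!!) keeps the sketch dependency-free; a Vector (as in the
-- answer) gives O(1) indexing instead.
countM :: [Int] -> Int -> Int -> State (M.Map (Int, Int) Int) Int
countM xs t i = memoM (t, i) $
  if | t == 0         -> pure 1
     | t < 0 || i < 0 -> pure 0
     | t < xs !! i    -> countM xs t (i - 1)
     | otherwise      ->
         (+) <$> countM xs (t - xs !! i) (i - 1)
             <*> countM xs t (i - 1)

-- Run the memoized computation starting from an empty cache.
countSubsets :: [Int] -> Int -> Int -> Int
countSubsets xs t i = evalState (countM xs t i) M.empty
```

    With positive elements the recurrence counts the subsets of the first i+1 elements that sum to t, so countSubsets [1,2,3] 4 2 evaluates to 1 (only {1,3} works).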

    We can make this even cleaner by dropping the monad (see the linked related questions). There is a popular package (MemoTrie) that handles the guts for you:

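    import Data.MemoTrie (memo2)
    import qualified Data.Vector as V
    
    memoTrieVersion :: [Int] -> Int -> Int -> Int
    memoTrieVersion xs = mgo
     where
     v = V.fromList xs
     -- Build the memo table once and send every recursive call through it.
     -- Calling `memo2 go` inside `go` may build a fresh table on each call.
     mgo = memo2 go
     go t i | t == 0 = 1
            | t < 0  = 0
            | i < 0  = 0
            | t < v V.! i = mgo t (i-1)
            | otherwise   = mgo (t - (v V.! i)) (i-1) + mgo t (i-1)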
    

    If you like the monadic style, you could always use the monad-memo package.
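    If you'd rather drop both the monad and the extra package, a lazily filled Data.Array (from the array package that ships with GHC) is another option; this is a swapped-in technique, not what MemoTrie does internally. It also speaks to the concern in the question that the (t,i) keys don't arrive in order: they don't have to, because every key the recursion can reach lies in a bounded box. A minimal sketch, assuming non-negative elements and -1 <= i0 < length xs; countLazy is a hypothetical name.

```haskell
import Data.Array (Array, listArray, range, (!))

-- Hypothetical lazy-array memoization of the question's recurrence.
-- Assumes non-negative elements and -1 <= i0 < length xs, so every (t, i)
-- the recursion reaches satisfies 0 <= t <= t0 and -1 <= i <= i0.
countLazy :: [Int] -> Int -> Int -> Int
countLazy xs t0 i0
  | t0 < 0    = 0
  | otherwise = table ! (t0, i0)
  where
    v :: Array Int Int
    v = listArray (0, length xs - 1) xs

    bnds :: ((Int, Int), (Int, Int))
    bnds = ((0, -1), (t0, i0))

    -- One lazy thunk per (t, i) key; each is evaluated at most once,
    -- the first time some other cell (or the caller) demands it.
    table :: Array (Int, Int) Int
    table = listArray bnds [f t i | (t, i) <- range bnds]

    -- The recurrence, reading sub-results from the table instead of
    -- recursing directly. Under the assumptions, t - v ! i stays in [0, t0].
    f :: Int -> Int -> Int
    f t i
      | t == 0    = 1
      | i < 0     = 0
      | t < v ! i = table ! (t, i - 1)
      | otherwise = table ! (t - v ! i, i - 1) + table ! (t, i - 1)
```

    With this, countLazy (replicate 30 1) 15 29 gives 155117520 (30 choose 15) after filling at most 16 * 31 = 496 cells, whereas the unmemoized go has to reach a base case returning 1 once for each of those 155117520 subsets.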

    EDIT: A mostly direct translation of your Python code to Haskell shows that an important difference is the immutability of variables. In your otherwise (else) case you call go twice, and the first call implicitly updates the cache (m) that the second call then uses, which is what saves computation. In Haskell, if you avoid both monads and lazy evaluation (for example, recursively defining a vector of results, which can be quite powerful), the simplest remaining solution is to pass your map (dictionary) around explicitly:

    import Data.Vector as V
    import Data.Map as M
    
    goWrapped :: Vector Int -> Int -> Int -> Int
    goWrapped xxs t i = fst $ goPythonVersion xxs t i mempty
    
    goPythonVersion :: Vector Int -> Int -> Int -> Map (Int,Int) Int -> (Int,Map (Int,Int) Int)
    goPythonVersion xxs t i m =
      let k = (t,i)
      in case M.lookup k m of -- if  k in m:
        Just r -> (r,m)       --     return m[k]
        Nothing ->
          let (res,m') | t == 0 = (1,m)
                       | t  < 0 = (0,m)
                       | i  < 0 = (0,m)
                       | t  < xxs V.! i = goPythonVersion xxs t (i-1) m
                       | otherwise  =
                          let (r1,m1) = goPythonVersion xxs (t - (xxs V.! i)) (i-1) m
                              (r2,m2) = goPythonVersion xxs t (i-1) m1
                          in (r1 + r2, m2)
          in (res, M.insert k res m')
    

    And while this version is a decent translation of the Python, I'd rather see a more idiomatic solution such as the one below. Notice that we bind variables to the result of the computation ("comp" for the Int and "m'" for the updated map), but thanks to lazy evaluation that work is only done when the cache doesn't yield a result.

    {-# LANGUAGE ViewPatterns #-}
    {-# LANGUAGE TupleSections #-}
    import Data.Vector (Vector)
    import qualified Data.Vector as V
    import Data.Map (Map)
    import qualified Data.Map as M
    
    goMoreIdiomatic :: Vector Int -> Int -> Int -> Map (Int,Int) Int -> (Int,Map (Int,Int) Int)
    goMoreIdiomatic xxs t i m =
      let cached = M.lookup (t,i) m
          -- Lazily bound: only evaluated if the cache misses. The view pattern
          -- inserts the freshly computed value into the returned map.
          ~(comp, M.insert (t,i) comp -> m')
            | t == 0 = (1,m)
            | t  < 0 = (0,m)
            | i  < 0 = (0,m)
            | t  < xxs V.! i = goMoreIdiomatic xxs t (i-1) m
            | otherwise  =
               let (r1,m1) = goMoreIdiomatic xxs (t - (xxs V.! i)) (i-1) m
                   (r2,m2) = goMoreIdiomatic xxs t (i-1) m1
               in (r1 + r2, m2)
        in maybe (comp,m') (,m) cached
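    To see the lazy-evaluation point on its own, here is a cut-down sketch of the same pattern with the recursion stripped out (lookupOrCompute is a hypothetical helper; only containers is assumed). On a cache hit the bound computation is never forced, so even a computation that would crash is harmless:

```haskell
{-# LANGUAGE TupleSections #-}
import qualified Data.Map as M

-- Hypothetical helper isolating the trick used in goMoreIdiomatic: the
-- "compute" branch is bound lazily and only evaluated on a cache miss.
lookupOrCompute
  :: Ord k
  => k                                -- cache key
  -> (M.Map k v -> (v, M.Map k v))    -- computation that threads the map
  -> M.Map k v                        -- current cache
  -> (v, M.Map k v)
lookupOrCompute key compute m =
  -- let-bound patterns are already lazy; the ~ only makes that explicit
  let ~(v, m') = compute m
  in maybe (v, M.insert key v m') (, m) (M.lookup key m)
```

    Here lookupOrCompute 1 (\_ -> error "forced!") (M.fromList [(1, "hit")]) returns "hit" without raising the error, while a miss runs the computation and adds its result to the map.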