Tags: performance, haskell, dynamic-programming, memoization, knapsack-problem

Implementing memoization efficiently on non-integral keys


I am new to Haskell and have been practicing by doing some simple programming challenges. For the last two days, I've been trying to implement the unbounded knapsack problem here. The algorithm I'm using is described on the Wikipedia page, although in this problem the word "weight" is replaced with "length". I started by writing the code without memoization:

maxValue :: [(Int,Int)] -> Int -> Int
maxValue [] _ = 0
maxValue ((l, val) : other) len =
    if l > len then
        skipValue
    else
        max skipValue takeValue
    where skipValue = maxValue other len                            -- never use this item again
          takeValue = val + maxValue ((l, val) : other) (len - l)  -- use it and keep it available

I had hoped Haskell would offer some convenient syntax to help me, something like #pragma memoize. When I looked for examples, though, the usual explanation was this Fibonacci code:

memoized_fib :: Int -> Integer
memoized_fib = (map fib [0 ..] !!)
   where fib 0 = 0
         fib 1 = 1
         fib n = memoized_fib (n-2) + memoized_fib (n-1)

Once I understood the concept behind this example, I was very disappointed. The method is super hacky and only works if 1) the input to the function is a single integer, and 2) the function computes its values recursively in the order f(0), f(1), f(2), ... But what if my parameters are vectors or sets? And if I want to memoize a function like f(n) = f(n/2) + f(n/3), I would have to compute f(i) for every i less than n, even though I don't need most of those values. (Others have pointed out that this last claim is false. Because the list is lazy, only the elements that are actually demanded get computed. However, reaching index n with !! still takes O(n) steps.)
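The f(n) = f(n/2) + f(n/3) case can be made concrete with a minimal sketch that memoizes only the keys the recursion actually reaches. The approach threads a Data.Map through the calls, looks each key up before computing it, and inserts the result on a miss. The sketch rests on two assumptions that the question does not specify: a base case f 0 = 1 and integer division for n/2 and n/3. The names fMemo and f are hypothetical.

```haskell
import qualified Data.Map.Strict as Map

-- Hypothetical recurrence, base case chosen for illustration only:
--   f 0 = 1,  f n = f (n `div` 2) + f (n `div` 3)   (assumes n >= 0)
-- The memo table is threaded explicitly: look the key up first, and only
-- compute (and insert) it on a miss.
fMemo :: Map.Map Integer Integer -> Integer -> (Map.Map Integer Integer, Integer)
fMemo memo 0 = (memo, 1)
fMemo memo n =
  case Map.lookup n memo of
    Just v  -> (memo, v)                     -- hit: reuse the stored result
    Nothing ->
      let (m1, a) = fMemo memo (n `div` 2)   -- first subproblem
          (m2, b) = fMemo m1 (n `div` 3)     -- second call sees m1's entries
          v       = a + b
      in (Map.insert n v m2, v)              -- store before returning

-- Start from an empty table and discard it at the end.
f :: Integer -> Integer
f = snd . fMemo Map.empty
```

With this scheme, f 1000000 stores at most a few hundred entries rather than a million, because every key it visits has the form n `div` (2^a * 3^b).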

I tried implementing what I wanted by passing, as an extra parameter, a memo table that is gradually filled in:

import qualified Data.Map.Strict as Map

maxValue :: Map.Map (Int, Int) Int -> [(Int,Int)] -> Int -> (Map.Map (Int, Int) Int, Int)
maxValue m [] len = (m, 0)
maxValue m ((l, val) : other) len =
    if l > len then
        (mapWithSkip, skipValue)
    else
        (mapUnion, max skipValue (takeValue+val))
    where (skipMap, skipValue) = maxValue m other len
          -- key: (number of items remaining, remaining length)
          mapWithSkip = Map.insertWith max (1 + length other, len) skipValue skipMap
          (takeMap, takeValue) = maxValue m ([(l, val)] ++ other) (len - l)
          mapWithTake = Map.insertWith max (1 + length other, len) (takeValue+val) mapWithSkip
          mapUnion = Map.union mapWithSkip mapWithTake

But this is too slow. I believe the cause is Map.union, which I think is O(n+m) rather than O(min(n,m)). This code also seems quite messy for something as simple as memoization. For this specific problem, I might get away with generalizing the hacky approach to two dimensions and computing a few extra values, but I want to know how to do memoization in a more general sense. How can I implement memoization in this more general form while keeping the complexity the code would have in an imperative language?
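The idea of generalizing the list trick to two dimensions can be sketched with a lazy boxed array from Data.Array. The array is indexed by (item index, remaining length) instead of a single Int. This is a minimal sketch that assumes every item length is at least 1 and the capacity is non-negative. The names maxValueArr, itemArr, table, and cell are illustrative.

```haskell
import Data.Array (Array, listArray, range, (!))

-- Unbounded knapsack, memoized in a lazy 2D array indexed by
-- (item index, remaining length). Each cell is a thunk that is evaluated
-- at most once, and only when something demands it.
-- Assumptions: every item length is >= 1 and cap >= 0.
maxValueArr :: [(Int, Int)] -> Int -> Int
maxValueArr items cap = table ! (0, cap)
  where
    n = length items
    itemArr :: Array Int (Int, Int)
    itemArr = listArray (0, n - 1) items      -- O(1) access to item i
    bnds = ((0, 0), (n, cap))
    table :: Array (Int, Int) Int
    table = listArray bnds [cell ix | ix <- range bnds]
    cell (i, len)
      | i == n    = 0                         -- no items left to consider
      | l > len   = skipValue                 -- item i does not fit
      | otherwise = max skipValue (val + table ! (i, len - l))  -- i may be reused
      where
        (l, val)  = itemArr ! i
        skipValue = table ! (i + 1, len)      -- never use item i again
```

The table has (n + 1) * (cap + 1) cells, so the running time is O(n * cap), which matches the usual imperative dynamic-programming solution. The trade-off is that every cell is allocated up front as a thunk, even the ones that are never demanded.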


Solution

  • My go-to way to do memoization in Haskell is usually MemoTrie. It's pretty straightforward, it's pure, and it usually does what I'm looking for.

    Without thinking too hard, you could produce:

    import Data.MemoTrie (memo2)
    maxValue :: [(Int,Int)] -> Int -> Int
    maxValue = memo2 go
      where
        go [] len = 0
        go lst@((l, val):other) len =
          if l > len then skipValue else max skipValue takeValue
          where
            skipValue = maxValue other len
            takeValue = val + maxValue lst (len - l)
    

    I don't have your inputs, so I don't know how fast this will be. It's also a little strange to memoize on the [(Int,Int)] input itself. I think you recognize this too, because your own attempt memoizes on the length of the remaining list, not on the list itself. If you want to do that, it makes sense to convert your list to an array with constant-time lookup and then memoize on the index. This is what I came up with:

    import Data.Array (listArray, (!))
    import Data.MemoTrie (memo2)

    maxValue :: [(Int,Int)] -> Int -> Int
    maxValue lst = memoGo 0
      where
        n = length lst
        values = listArray (0, n - 1) lst    -- O(1) access to item i
        -- All recursive calls go through memoGo (not go), so results are cached.
        memoGo = memo2 go
        go i _ | i >= n = 0
        go i len = if l > len then skipValue else max skipValue takeValue
          where
            (l, val) = values ! i
            skipValue = memoGo (i+1) len
            takeValue = val + memoGo i (len - l)
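The general idea behind trie-based memoization can be shown with a simplified, hand-rolled version for non-negative Int keys. This is an illustration, not MemoTrie's actual code, and the names Tree, nats, index, memoNat, and memoFib are hypothetical. The function's results are stored in a lazy infinite binary tree with one node per natural number. Because the tree is built lazily and shared, each result is computed at most once. A lookup takes O(log n) steps, compared with the O(n) walk that (!!) needs in the list-based Fibonacci example.

```haskell
-- A lazy, infinite binary tree holding one value per natural number.
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap f (Tree l x r) = Tree (fmap f l) (f x) (fmap f r)

-- Look up the value stored for natural number n (assumes n >= 0).
index :: Tree a -> Int -> a
index (Tree _ x _) 0 = x
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

-- The tree whose node for n holds n itself: a subtree rooted at n with
-- step s has children rooted at n + s and n + 2s, each with step 2s.
nats :: Tree Int
nats = go 0 1
  where
    go n s = Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

-- Memoize a function on non-negative Ints: the tree of results is shared,
-- so each f n is evaluated at most once.
memoNat :: (Int -> a) -> Int -> a
memoNat f = index (fmap f nats)

-- Example: Fibonacci, with recursive calls going through the memoized function.
memoFib :: Int -> Integer
memoFib = memoNat fibStep
  where
    fibStep :: Int -> Integer
    fibStep 0 = 0
    fibStep 1 = 1
    fibStep k = memoFib (k - 1) + memoFib (k - 2)
```

For other key types, the same idea carries over by choosing a trie shape that matches the key's structure, such as pairs or lists. MemoTrie's HasTrie class abstracts over exactly that choice.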