Memoizing a function of type [Integer] -> a


How can I efficiently memoize an expensive function f :: [Integer] -> a that is defined for all finite lists of integers and satisfies f . sort = f?

My typical use case: given a list of integers as, I need the values f (a:as) for various Integer a. So I'd like to build up, at the same time, a directed labelled graph whose vertices are pairs of an Integer list and its function value. An edge labelled a goes from (as, f as) to (bs, f bs) if and only if a:as = bs.

Stealing from a brilliant answer by Edward Kmett, I simply copied

{-# LANGUAGE BangPatterns #-}
data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q,0) -> index l q
  (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
  where go !n !s = Tree (go l s') n (go r s')
          where l = n + s
                r = l + s
                s' = s * 2

and adapted his idea to my problem as

-- directed graph labelled by Integers
data Graph a = Graph a (Tree (Graph a))
instance Functor Graph where
  fmap f (Graph a t) = Graph (f a) (fmap (fmap f) t)

-- walk the graph following the given labels
walk :: Graph a -> [Integer] -> a
walk (Graph a _) [] = a
walk (Graph _ t) (x:xs) = walk (index t x) xs

-- graph of all finite integer sequences
intSeq :: Graph [Integer]
intSeq = Graph [] (fmap (\n -> fmap (n:) intSeq) nats)

-- could be replaced by Data.Strict.Pair
data StrictPair a b = StrictPair !a !b
  deriving Show

-- f = sum modified according to Edward's idea (the real function is more complicated)
g :: ([Integer] -> StrictPair Integer [Integer]) -> [Integer] -> StrictPair Integer [Integer]
g _ [] = StrictPair 0 []
g mf (a:as) = StrictPair (a+x) (a:as)
  where StrictPair x _ = mf as

g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) intSeq

g_m :: [Integer] -> StrictPair Integer [Integer]
g_m = walk g_graph
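
For readers who have not seen the trick that Tree, index and nats come from, the one-dimensional case may be easier to follow. The following is a minimal sketch, restating those three definitions so that it compiles on its own; fibOpen, fibTree and fibMemo are illustrative names that do not appear in the original code. The function is written with an explicit parameter for its recursive calls, mapped over the infinite tree of naturals, and then looked up by index, so each value is computed at most once. Note that index only terminates for non-negative arguments.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Infinite binary tree used as a lazy lookup table for the naturals.
data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

-- Look up the entry for a non-negative index.
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

-- The tree whose entry at index n is n itself.
nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l = n + s
        r = l + s
        s' = s * 2

-- Fibonacci with the recursive calls abstracted out as mf (illustrative).
fibOpen :: (Integer -> Integer) -> Integer -> Integer
fibOpen _ 0 = 0
fibOpen _ 1 = 1
fibOpen mf n = mf (n - 1) + mf (n - 2)

-- Lazy table of all results; recursive calls go back through the table.
fibTree :: Tree Integer
fibTree = fmap (fibOpen fibMemo) nats

fibMemo :: Integer -> Integer
fibMemo = index fibTree
```

Graph, walk and g_graph above follow the same pattern, with a whole sub-graph stored at each index instead of a single value.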

This works OK, but since f is independent of the order of the integers (though not of how often each occurs), there should be only one vertex in the graph for all integer lists that are equal up to ordering.
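
To make the redundancy concrete, here is a small sketch that counts how many vertices the graph of all sequences spends on a single multiset. It uses only Data.List; orderings, canonical and vertexCount are hypothetical helper names introduced for illustration, and sorting is used merely as one possible canonical form.

```haskell
import Data.List (nub, permutations, sort)

-- Distinct orderings of a list. Each one is a separate path, and hence a
-- separate vertex, in the graph of all integer sequences.
orderings :: [Integer] -> [[Integer]]
orderings = nub . permutations

-- Hypothetical canonical key: all orderings of one multiset share it.
canonical :: [Integer] -> [Integer]
canonical = sort

-- Number of vertices the sequence graph uses for one multiset.
vertexCount :: [Integer] -> Int
vertexCount = length . orderings
```

For example, vertexCount [1,2,3] is 6 and vertexCount [1,1,2] is 3, while all of those orderings map to a single canonical key. A graph with one vertex per multiset would store f once in each of these cases.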

How do I achieve this?


Solution

  • After reading the functional pearl "Trouble Shared is Trouble Halved" by Richard Bird and Ralf Hinze, I understood how to implement what I was looking for two years ago (again based on Edward Kmett's trick):

    {-# LANGUAGE BangPatterns #-}
    import Data.Function (fix)
    
    data Tree a = Tree (Tree a) a (Tree a)
      deriving Show
    
    instance Functor Tree where
      fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)
    
    index :: Tree a -> Integer -> a
    index (Tree _ m _) 0 = m
    index (Tree l _ r) n = case (n - 1) `divMod` 2 of
      (q,0) -> index l q
      (q,1) -> index r q
    
    nats :: Tree Integer
    nats = go 0 1
      where go !n !s = Tree (go l s') n (go r s')
              where l = n + s
                    r = l + s
                    s' = s * 2
    
    data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))
    
    val :: IntSeqTree a -> a
    val (IntSeqTree a _) = a
    
    step :: Integer -> IntSeqTree t -> IntSeqTree t
    step n (IntSeqTree _ ts) = index ts n
    
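    -- Graph of all finite multisets of non-negative integers. Each vertex holds
    -- its elements in descending order; every ordering of the same elements
    -- leads to the same vertex.
    intSeqTree :: IntSeqTree [Integer]
    intSeqTree = fix $ create []
      where create p x = IntSeqTree p $ fmap (extend x) nats
            extend x n = case span (>n) (val x) of
                           -- n is at least as large as every element: new vertex
                           ([], p) -> fix $ create (n:p)
                           -- otherwise insert n in order and walk from the root
                           (m, p)  -> foldr step intSeqTree (m ++ n:p)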
    
    instance Functor IntSeqTree where
      fmap f (IntSeqTree a t) = IntSeqTree (f a) (fmap (fmap f) t)
    

    In my use case I have hundreds or thousands of similar integer sequences, each a few hundred entries long, that are generated incrementally. For me this is therefore cheaper than sorting each sequence before looking up the function value (which I will access by using fmap on intSeqTree).
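
One point that may be worth checking when mapping over the structure: fmap rebuilds it path by path, so two different orderings that end in the same vertex of intSeqTree get two separately evaluated copies of the mapped value. The sketch below is a hypothetical variant, not the code from the answer above: memoTree, sumTree and lookupSum are illustrative names, and sum stands in for the expensive function. It stores the function value when a vertex is created, so the value is shared together with the vertex. Like index, it only handles non-negative integers.

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)

data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
  fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

-- Only terminates for non-negative indices.
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
  (q, 0) -> index l q
  (q, _) -> index r q

nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l = n + s
        r = l + s
        s' = s * 2

data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))

val :: IntSeqTree a -> a
val (IntSeqTree a _) = a

step :: Integer -> IntSeqTree t -> IntSeqTree t
step n (IntSeqTree _ ts) = index ts n

-- Hypothetical variant of intSeqTree: every vertex stores its canonical
-- (descending) list together with f applied to that list.
memoTree :: ([Integer] -> a) -> IntSeqTree ([Integer], a)
memoTree f = root
  where
    root = fix (create [])
    create p x = IntSeqTree (p, f p) (fmap (extend x) nats)
    extend x n = case span (> n) (fst (val x)) of
      ([], p) -> fix (create (n : p))          -- new canonical vertex
      (m, p)  -> foldr step root (m ++ n : p)  -- redirect to the existing one

-- Stand-in for the expensive function (assumption: sum, as in the question).
sumTree :: IntSeqTree ([Integer], Integer)
sumTree = memoTree sum

-- Walk from the root; foldr applies the last list element first.
lookupSum :: [Integer] -> ([Integer], Integer)
lookupSum = val . foldr step sumTree
```

With these definitions, lookupSum [3,1,2] and lookupSum [2,3,1] both arrive at the vertex holding ([3,2,1], 6). For incremental use, one would keep the current vertex and apply step a to it rather than walking from the root each time.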