Tags: haskell, functional-programming, state, immutability

How can I improve the API of a pure function that returns state in Haskell?


I am attempting to write a function which, given some function f, memoizes f in such a way that after defining g = memoize f and evaluating g x, all subsequent invocations of g with the argument x simply return the cached result.

However, I am struggling to come up with an implementation that improves upon the explicit state passing that's required with the following:

import Data.Map (Map)
import qualified Data.Map as Map

memoize :: Ord t => (t -> a) -> Map t a -> t -> Map t a
memoize f m a = case Map.lookup a m of
    Just _  -> m
    Nothing -> Map.insert a (f a) m

with a contrived example to show its usage:

main :: IO ()
main = do
  let memoPlusOne = memoize (+ 1)
      m  = memoPlusOne Map.empty 1
      mm = memoPlusOne m 1
  print mm

I am aware that there are other, better ways to memoize functions in Haskell, but my question is more concerned with improving on the general pattern of passing state to a function to avoid any state mutations that would otherwise be encapsulated as in other languages, e.g. as in this example in OCaml:

let memo_rec f =
  let h = Hashtbl.create 16 in
  let rec g x =
    try Hashtbl.find h x
    with Not_found ->
      let y = f g x in
      (* update h in place *)
      Hashtbl.add h x y;
      y
  in
  g

Solution

  • my question is more concerned with improving on the general pattern of passing state to a function to avoid any state mutations that would otherwise be encapsulated as in other languages

    There are lots of ways to do fun things with immutable state in Haskell! I can give some examples, but I also feel obligated to point out that the most efficient and user-friendly version of memoize will likely use unsafe operations (such as unsafePerformIO) under the hood, and if that's what you want, you're probably better off using an existing library than rolling your own. But if you're experimenting, then have at it!

    That said, before we begin with new tricks, let's take a look at what you've created. The biggest problem with your current code is that the type is off. You have memoize f m :: t -> Map t a, which never returns the value f t that the caller actually wants. After all, the theoretically best type signature for memoize is (t -> a) -> t -> a.
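    To make the problem concrete: with the question's signature, a caller who wants f t has to look it up again in the returned table, and only gets back a Maybe. A minimal sketch (`resultOf` is a hypothetical helper, not part of the question):

```haskell
import Data.Map (Map)
import qualified Data.Map as Map

-- The question's memoize: it returns only the updated memo table.
memoize :: Ord t => (t -> a) -> Map t a -> t -> Map t a
memoize f m a = case Map.lookup a m of
  Just _  -> m
  Nothing -> Map.insert a (f a) m

-- Hypothetical helper: recovering f t needs a second lookup, and the type
-- system cannot see that the key is always present after memoize runs.
resultOf :: Ord t => (t -> a) -> Map t a -> t -> (Map t a, Maybe a)
resultOf f m t = let m' = memoize f m t in (m', Map.lookup t m')

main :: IO ()
main = print (snd (resultOf (+ 1) Map.empty (1 :: Int)))  -- Just 2
```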

    You can fix this by changing memoize to:

    memoize :: Ord t => (t -> a) -> Map t a -> t -> (Map t a, a)
    memoize f m t = case Map.lookup t m of
        Just a  -> (m, a)
        Nothing -> let a = f t in (Map.insert t a m, a)
    

    With this, you compute the new memoized state but also return the result, which is ultimately what memoizing is for anyway. This may seem like an irrelevant change (you might ask: can't you just look up the right a in the Map t a afterward?), but this type signature is useful when exploring how to handle state.
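    Even without any further abstraction, the pair-returning version makes the threading mechanical, and memoize f already has the shape s -> x -> (s, y) that mapAccumL from Data.List expects. A minimal sketch (`byHand` and `withMapAccumL` are illustrative names only):

```haskell
import Data.List (mapAccumL)
import Data.Map (Map)
import qualified Data.Map as Map

-- Pair-returning memoize: the new memo table plus the result.
memoize :: Ord t => (t -> a) -> Map t a -> t -> (Map t a, a)
memoize f m t = case Map.lookup t m of
  Just a  -> (m, a)
  Nothing -> let a = f t in (Map.insert t a m, a)

-- Threading the table by hand: each call receives the previous call's table.
byHand :: (Map Int Int, [Int])
byHand =
  let (m1, r1) = memoize (+ 1) Map.empty 1
      (m2, r2) = memoize (+ 1) m1 3
      (m3, r3) = memoize (+ 1) m2 1  -- cache hit: 1 is already in m2
  in (m3, [r1, r2, r3])

-- mapAccumL performs the same threading over a whole list of inputs.
withMapAccumL :: (Map Int Int, [Int])
withMapAccumL = mapAccumL (memoize (+ 1)) Map.empty [1, 3, 1]

main :: IO ()
main = do
  print byHand         -- (fromList [(1,2),(3,4)],[2,4,2])
  print withMapAccumL  -- (fromList [(1,2),(3,4)],[2,4,2])
```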


    Now, to get back to your question: how can we improve on this general pattern of passing state? You may notice that your function takes a state and returns a new state, and this is what the State monad is all about. Indeed, State is simply defined as:

    newtype State s a = State {runState :: s -> (s, a)}
    

    (In the transformers package, where you would probably import it from, State s is actually a synonym for StateT s Identity, and runState returns the pair as (a, s), but it's isomorphic to this.) So, you could rewrite memoize to be in the State monad like so:

    memoize :: Ord t => (t -> a) -> t -> State (Map t a) a
    memoize f t = do
        m <- get
        case Map.lookup t m of
            Just a  -> pure a
            Nothing -> do
                let a = f t
                put (Map.insert t a m)
                pure a
    

    Annoyingly, you can only use this while you're in the monad. For instance:

    main :: IO ()
    main = do
      let memoPlusOne = memoize (+ 1)
          -- runState (transformers/mtl) returns (result, final state)
          (results, _table) = flip runState Map.empty $ do
            res1 <- memoPlusOne 1
            res2 <- memoPlusOne 3
            pure [res1, res2]
      print results
    

    If you only need the results, you can use evalState instead of runState; execState, by contrast, returns just the final memo table.
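    Putting the State version together into a runnable program, as a sketch: it assumes the transformers package's Control.Monad.Trans.State (mtl's Control.Monad.State exports the same names). Because memoize (+ 1) is an ordinary monadic function, traverse threads the table through a whole list of calls. `memoizePair` and `memoizeS` are hypothetical names illustrating the isomorphism with the pair-returning version: transformers' state expects a function s -> (a, s), so the pair only needs a swap.

```haskell
import Control.Monad.Trans.State (State, evalState, execState, get, put, runState, state)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Tuple (swap)

-- State-based memoize, as in the answer.
memoize :: Ord t => (t -> a) -> t -> State (Map t a) a
memoize f t = do
  m <- get
  case Map.lookup t m of
    Just a  -> pure a
    Nothing -> do
      let a = f t
      put (Map.insert t a m)
      pure a

-- Hypothetical alternative: reuse the pair-returning version via `state`.
memoizePair :: Ord t => (t -> a) -> Map t a -> t -> (Map t a, a)
memoizePair f m t = case Map.lookup t m of
  Just a  -> (m, a)
  Nothing -> let a = f t in (Map.insert t a m, a)

memoizeS :: Ord t => (t -> a) -> t -> State (Map t a) a
memoizeS f t = state (\m -> swap (memoizePair f m t))  -- (s, a) -> (a, s)

-- traverse threads the memo table through every call, in order.
calls :: State (Map Int Int) [Int]
calls = traverse (memoize (+ 1)) [1, 3, 1]

main :: IO ()
main = do
  print (runState calls Map.empty)   -- ([2,4,2],fromList [(1,2),(3,4)])
  print (evalState calls Map.empty)  -- results only: [2,4,2]
  print (execState calls Map.empty)  -- table only: fromList [(1,2),(3,4)]
  print (evalState (traverse (memoizeS (+ 1)) [1, 3, 1 :: Int]) Map.empty)
```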


    Instead of hiding the state in a monad, we can hide the state directly in the function. That is, instead of returning a new state, let's return a new function:

    newtype Memoized t a = Memo { runMemo :: t -> (Memoized t a, a) }
    
    memoize :: Ord t => (t -> a) -> Memoized t a
    memoize f = Memo $ go Map.empty
      where
        go m t = case Map.lookup t m of
          Just a  -> (Memo $ go m, a)
          Nothing -> let a = f t in (Memo $ go $ Map.insert t a m, a)
    

    This trick bundles the state into a new Memoized value every time you call the memoized function. As long as you always make the next call on the newest Memoized value, every earlier result stays cached. Consider this version of main:

    main :: IO ()
    main = do
      let memoPlusOne = memoize (+ 1)
      let (memoPlusOne', res1) = runMemo memoPlusOne 1
      let (memoPlusOne'', res2) = runMemo memoPlusOne' 3
      print [res1, res2]
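    Since runMemo has the type Memoized t a -> t -> (Memoized t a, a), the same mapAccumL trick from Data.List can pass each new Memoized value on to the next call. A self-contained sketch (`memoPlusOne` and `results` are illustrative names):

```haskell
import Data.List (mapAccumL)
import qualified Data.Map as Map

-- A memoized function that returns its own successor along with each result.
newtype Memoized t a = Memo { runMemo :: t -> (Memoized t a, a) }

memoize :: Ord t => (t -> a) -> Memoized t a
memoize f = Memo $ go Map.empty
  where
    go m t = case Map.lookup t m of
      Just a  -> (Memo $ go m, a)
      Nothing -> let a = f t in (Memo $ go $ Map.insert t a m, a)

memoPlusOne :: Memoized Int Int
memoPlusOne = memoize (+ 1)

-- mapAccumL feeds each call's new Memoized value into the next call,
-- so the second lookup of 1 is answered from the cache.
results :: [Int]
results = snd (mapAccumL runMemo memoPlusOne [1, 3, 1])

main :: IO ()
main = print results  -- [2,4,2]
```

    One caveat of this design: nothing forces a caller to use the newest value. Calling runMemo on an older Memoized still works, but it only sees the smaller table captured when that value was created.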