Tags: haskell, caching, ghc

Is there a way to tell the Haskell runtime to use memoization for functions with approximately the same inputs?


I have a Haskell function that takes several floats as inputs. This function is somewhat expensive to compute, so it would be really nice if I could tell Haskell, "If all inputs are within +/- 1 ppm of a previous call's inputs, reuse that call's cached result."

Is there a way to do this?

I suppose I could write a wrapper that does something like this to all inputs:

1e-6 * fromIntegral (floor (x * 1e6))

which increases the likelihood that Haskell will treat two almost-identical floats as the same number. However, algorithms sometimes run into trouble when they rely on exact equality of two floating-point numbers, and that is something I always try to avoid.
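
For concreteness, the kind of wrapper I have in mind would look something like this (just a sketch; quantize and expensive are names I'm making up for illustration):

-- Round each input down to the nearest 1e-6 before calling the expensive
-- function, so nearly identical inputs become bit-identical.
quantize :: Double -> Double
quantize x = 1e-6 * fromIntegral (floor (x * 1e6) :: Integer)

-- stand-in for the real expensive function
expensive :: Double -> Double -> Double
expensive p q = p ** q

expensiveApprox :: Double -> Double -> Double
expensiveApprox a b = expensive (quantize a) (quantize b)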

So before I do this, is there a way to tell GHC's memoization mechanism to tolerate a predetermined amount of slop?


Solution

  • Expanding on Daniel's comment, GHC does not have any sort of caching or memoization built in, at all.

    This is easy to see:

    -- foo.hs
    import Debug.Trace (trace)
    
    f :: () -> Char
    f () = trace "f called" '7'
    
    main :: IO ()
    main = do
      print $ f ()
      print $ f ()
      print $ f ()
    

    Now if you run runhaskell foo.hs, you should see this:

    f called
    '7'
    f called
    '7'
    f called
    '7'
    

    Despite the fact that f's input type () has literally only a single possible non-bottom value, GHC hasn't cached f () as the return value '7'; it has evaluated the call f () three times, triggering the trace¹ each time.

    People sometimes talk about lazy evaluation as involving some sort of memoization, but I honestly don't know why. Sure, we can do this:

    -- bar.hs
    import Debug.Trace (trace)
    
    f :: () -> Char
    f () = trace "f called" '7'
    
    main :: IO ()
    main = do
      putStrLn "before r is bound"
      let r = f ()
      putStrLn "after r is bound"
      print r
      print r
      print r
    
    

    And runhaskell bar.hs will print this:

    ❯ runhaskell bar.hs
    before r is bound
    after r is bound
    f called
    '7'
    '7'
    '7'
    

    Here we can see that r is only evaluated once, and the evaluation didn't happen at the point in main where r = f () was defined, but only later when r was first used in print r. So it seems GHC has "cached" the value of r across its 3 usages.

    But that is completely unsurprising! If we bind a variable in just about any mainstream programming language and then use the variable multiple times, we expect that the expression bound to the variable will be evaluated at most once. This isn't "caching", it's just ordinary variable binding. Frequently it's the entire point of binding a variable. Laziness in Haskell means the variable can be bound without actually evaluating the RHS; the system can wait to see when (if ever) the first usage of that variable occurs. But that has nothing to do with memoization; at no point is a function call first checking its argument to see if a prior call has already been evaluated on that argument.
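
    For contrast, here is roughly what actual memoization looks like when you write it yourself: an explicit table that every call consults before doing any work. This is only an illustrative sketch using Data.IORef and Data.Map; nothing like it is built into GHC, and the name memoize is mine.

    import Data.IORef (newIORef, readIORef, modifyIORef')
    import qualified Data.Map.Strict as Map
    
    -- Wrap a function with a mutable lookup table. Each call first checks
    -- whether this exact argument has already been seen.
    memoize :: Ord a => (a -> b) -> IO (a -> IO b)
    memoize f = do
      ref <- newIORef Map.empty
      pure $ \x -> do
        table <- readIORef ref
        case Map.lookup x table of
          Just y  -> pure y                -- cache hit: reuse the stored result
          Nothing -> do                    -- cache miss: compute and remember it
            let y = f x
            modifyIORef' ref (Map.insert x y)
            pure y

    Note that the caller has to opt into this explicitly (and in this sketch the memoized function even lives in IO); GHC never inserts such a table on its own.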


    ¹ If you're not aware, the Debug.Trace module has a variety of functions intended to help with "print-debugging" without having to rewrite all the code you're trying to debug to be in IO, which is normally required to print to the console.

    trace simply prints its first argument to the console and then returns its second argument, but its type lies to the compiler so that it looks like a pure function. Thus having a trace call inside f lets us see when the runtime system evaluates a call. Because the compiler is unaware of trace's side effect of printing to the console, the trace call won't stop it from caching a call to f or from applying any other optimisation that changes the number of times f is called.


    Now, going back to foo.hs, it is true that GHC isn't actually guaranteed to print f called 3 times. If we compile foo.hs with optimization and then run it, we see something different:

    ❯ ghc -O foo.hs
    [1 of 1] Compiling Main             ( foo.hs, foo.o )
    Linking foo ...
    
    ❯ ./foo
    f called
    '7'
    '7'
    '7'
    

    Now f () has only been evaluated once! But this still isn't caching or memoization. It's just a compile-time optimization: GHC noticed that the same expression occurs multiple times in the source code, and it believes the expression is pure (because trace lies about its type), so it transformed the program so that the expression is evaluated only once. This is ordinary common-subexpression elimination, the same as an optimising C compiler might do. It has nothing to do with memoization; nothing at runtime is "remembering" the result of a call and checking later calls to see whether they have the same argument. It's based entirely on the same expression occurring in the source code, not on the same value turning up at runtime.


    So, with no memoization or caching mechanism built into GHC at all, there is obviously nothing to adjust to make it memoize float arguments fuzzily. If you want memoization in Haskell you need to write code that does it yourself, or use one of the memoization libraries on Hackage. I don't know whether any of them let you supply a custom function for deciding when a call matches a previously recorded one, which is what you would need to make your float functions use approximate equality.
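
    If you do roll your own, one workable approach is to key the cache on quantized inputs rather than on the raw floats, which is essentially your floor trick applied at the cache boundary instead of inside the function. A rough sketch along the lines of the memoize example above (memoizeFuzzy and the 1e6 scaling are my own choices, not anything from a particular library):

    import Data.IORef (newIORef, readIORef, modifyIORef')
    import qualified Data.Map.Strict as Map
    
    -- Memoize a two-argument float function, treating inputs as equal
    -- when they round to the same integer after scaling by 1e6.
    memoizeFuzzy :: (Double -> Double -> Double) -> IO (Double -> Double -> IO Double)
    memoizeFuzzy f = do
      ref <- newIORef Map.empty
      let key x = round (x * 1e6) :: Integer     -- quantized cache key
      pure $ \a b -> do
        table <- readIORef ref
        case Map.lookup (key a, key b) table of
          Just y  -> pure y
          Nothing -> do
            let y = f a b                        -- compute with the original inputs
            modifyIORef' ref (Map.insert (key a, key b) y)
            pure y

    Whether you do something like this, pick a library, or stick with rounding the inputs at each call site, the approximate matching has to live in code that you (or a library author) write; GHC itself has nothing to offer here.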