
Caching an expensive-to-compute result in a class instance


Let's say I've got something like the following:

class C a where
  f :: a -> Text

newtype T a = T a

expensiveFunction :: (Bounded a, Enum a, ToText a) => Map a Text
expensiveFunction = _ 

cheapFunction :: (Bounded a, Enum a, ToText a) => T a -> Map a Text -> Text
cheapFunction = _

instance (Bounded a, Enum a, ToText a) => C (T a) where
  f x = cheapFunction x expensiveFunction

What I would like to do is this:

newtype U = U Int
  deriving C via (T Int)

Here's the problem I see: expensiveFunction will be evaluated on each call to f, which is bad.

I did have this thought, of changing the T a instance to:

instance (Bounded a, Enum a, ToText a) => C (T a) where
  f = flip cheapFunction expensiveFunction

And then hoping that f (i.e. the function of one argument) is computed once, which hopefully means that expensiveFunction is only computed once.

But I'm not sure how reliable this is. I'm guessing that probably won't work if the instance requested is polymorphic, but if the instance is monomorphic (i.e. T Int above) then I was hoping this would do the trick.

But f isn't an ordinary function, it's an instance method, and since f is defined in the class C as a function, not a standalone value, perhaps it will be treated like a function and recomputed each time.

Any thoughts? Is there a better approach I should be using here? I'd like to keep the nice deriving via syntax if possible.


Solution

  • I did some quick experiments. Your suggestion of using flip to move expensiveFunction outside the final application of f seems to work quite well.

    Here's what I used:

    {-# LANGUAGE TypeApplications #-}
    module Main where
    
    import C
    import U
    
    main :: IO ()
    main = do
      putStrLn "f @(T Int)"
      let a = f (T @Int 1)
          b = f (T @Int 2)
          c = f (T @Int 3)
      a `seq` b `seq` c `seq` print [a, b, c]
    
      putStrLn "\n\nf @U"
      let x = f (U 4)
          y = f (U 5)
          z = f (U 6)
      x `seq` y `seq` z `seq` print [x, y, z]
    
      putStrLn "\n\ng"
      let g = f @(T Int)
          i = g (T 7)
          j = g (T 8)
          k = g (T 9)
      i `seq` j `seq` k `seq` print [i, j, k]
    
    
    {-# LANGUAGE BangPatterns, DefaultSignatures, ImportQualifiedPost #-}
    module C
      ( C (..)
      , T (..)
      , ToText (..)
      , expensiveFunction
      , cheapFunction
      )
    where
    
    
    import Data.Map (Map)
    import Data.Text (Text)
    import Data.Text qualified as Text
    
    import Debug.Trace (trace)
    
    
    class ToText a where
      toText :: a -> Text
      default toText :: Show a => a -> Text
      toText = Text.pack . show
    
    
    instance ToText Int
    
    
    class C a where
      f :: a -> Text
    
    newtype T a = T a
    
    expensiveFunction :: (Bounded a, Enum a, ToText a, Ord a) => Map a Text
    expensiveFunction = trace "expensive" $ mempty
    
    cheapFunction :: (Bounded a, Enum a, ToText a) => T a -> Map a Text -> Text
    cheapFunction (T x) !_ = trace "cheap" $ toText x
    
    instance (Bounded a, Enum a, ToText a, Ord a) => C (T a) where
      f = flip cheapFunction expensiveFunction
    
    {-# LANGUAGE DerivingVia #-}
    module U
      ( U (..)
      )
    where
    
    import C
    
    newtype U = U Int
      deriving C via (T Int)
    

    Without optimisations I get this:

    f @(T Int)
    expensive
    cheap
    expensive
    cheap
    expensive
    cheap
    ["1","2","3"]
    
    
    f @U
    expensive
    cheap
    cheap
    cheap
    ["4","5","6"]
    
    
    g
    expensive
    cheap
    cheap
    cheap
    ["7","8","9"]
    

    So as you thought, a polymorphic instance like the one for T re-evaluates expensiveFunction each time. In the compiled code, f for T a is a function of the dictionaries for Bounded a, etc., and the evaluation of expensiveFunction happens inside the application of f to those instance dictionaries, even though it is outside the application of f to its final argument.
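
    To make the dictionary argument concrete, here is a rough sketch of what dictionary passing looks like when written out by hand. It is only a model, not GHC's actual output: the record types ToTextDict and CDict and the names cDictT, cDictTInt, expensiveTable and cheap are all hypothetical stand-ins, and the several class constraints are collapsed into a single dictionary.

```haskell
module Main where

import Data.Map (Map)
import qualified Data.Map as Map
import Data.Text (Text)
import qualified Data.Text as Text
import Debug.Trace (trace)

-- Hypothetical explicit dictionaries standing in for the class constraints.
newtype ToTextDict a = ToTextDict { toTextD :: a -> Text }
newtype CDict a = CDict { fD :: a -> Text }

newtype T a = T a

-- Stand-in for expensiveFunction: it needs the dictionary for `a`.
expensiveTable :: ToTextDict a -> Map a Text
expensiveTable _ = trace "expensive" Map.empty

-- Stand-in for cheapFunction: forces the table, then renders the value.
cheap :: ToTextDict a -> T a -> Map a Text -> Text
cheap d (T x) table = table `seq` toTextD d x

-- A polymorphic instance behaves like a function from dictionaries to a
-- dictionary: the table thunk is created anew on every application.
cDictT :: ToTextDict a -> CDict (T a)
cDictT d = CDict (flip (cheap d) (expensiveTable d))

intDict :: ToTextDict Int
intDict = ToTextDict (Text.pack . show)

-- A monomorphic instance behaves like a constant: one dictionary, so one
-- table thunk shared by every call that goes through it.
cDictTInt :: CDict (T Int)
cDictTInt = cDictT intDict

main :: IO ()
main = do
  -- Two separate applications of cDictT, so two separate table thunks.
  print (fD (cDictT intDict) (T 1), fD (cDictT intDict) (T 2))
  -- One shared dictionary, so one table thunk.
  print (fD cDictTInt (T 3), fD cDictTInt (T 4))
```

    In this model the first print corresponds to calling f @(T Int) directly and the second to calling f @U or g, which is the distinction the unoptimised trace above shows.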

    But the f for U is able to hold on to a single shared expensiveFunction, so it's only evaluated once even without any optimisations.

    And as shown with g, you can always define a monomorphic variant, which will then ensure that expensiveFunction is shared across all uses of the monomorphic variant.

    When compiling with optimisations I get this:

    f @(T Int)
    expensive
    cheap
    cheap
    cheap
    ["1","2","3"]
    
    
    f @U
    cheap
    cheap
    cheap
    ["4","5","6"]
    
    
    g
    cheap
    cheap
    cheap
    ["7","8","9"]
    

    The compiler has noticed that all of these are ultimately using expensiveFunction @Int, and shared it even across the instances for the different newtypes.

    With optimisations, I get this behaviour even without applying your flip optimisation.

    Personally, I would be willing to rely on the sharing of expensiveFunction across f @U calls (and g calls) in the unoptimised version. It conforms to what I would expect from a fairly naive mental model where sharing is determined by let-binding and scopes, so I consider it fairly unlikely that perturbations from changing the implementations or from future compiler versions would perform worse than this. So if sharing in monomorphic instances like C U and monomorphic wrappers like g is enough, then I'd just apply your flip optimisation and be content.

    I'd be a little less comfortable relying on the precise behaviour of the optimisations; they're probably affected by how much the compiler decides to inline, which can easily change in future compiler versions or as I make changes to the code. So if sharing across all the instances (and even for polymorphic instances like C (T a)) is critical, then I'd want to think about writing less "nice" code that makes the sharing more explicit. But if this extra sharing is simply a nice performance boost that is not mission-critical, then no need to worry too much.
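
    As one possible way of making the sharing more explicit while keeping deriving via, the table could be moved into a helper class whose instances are monomorphic. This is only a sketch under assumptions: HasTable and table are hypothetical names, Show stands in for ToText, and Word8 replaces Int so that enumerating minBound .. maxBound stays small.

```haskell
{-# LANGUAGE DerivingVia #-}
module Main where

import Data.Map (Map)
import qualified Data.Map as Map
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Word (Word8)
import Debug.Trace (trace)

class C a where
  f :: a -> Text

newtype T a = T a

-- Hypothetical helper class: each instance holds its own precomputed table.
class Ord a => HasTable a where
  table :: Map a Text

-- Stand-in for the question's expensiveFunction (Show replaces ToText).
expensiveFunction :: (Bounded a, Enum a, Ord a, Show a) => Map a Text
expensiveFunction =
  trace "expensive" $
    Map.fromList [(x, Text.pack (show x)) | x <- [minBound .. maxBound]]

-- Monomorphic instance: this `table` takes no dictionary arguments, so it
-- is a plain top-level constant rather than a function.
instance HasTable Word8 where
  table = expensiveFunction

-- The polymorphic instance only looks the table up in the HasTable
-- dictionary it is given; it never builds the table itself.
instance HasTable a => C (T a) where
  f (T x) = Map.findWithDefault (Text.pack "?") x table

newtype U = U Word8
  deriving C via (T Word8)

main :: IO ()
main = mapM_ (print . f) [U 4, U 5, U 6]
```

    The cost is one extra instance per key type. In exchange, whether the table is shared no longer depends on how f happens to be written or inlined, since every use at a given key type goes through the same HasTable instance.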