Tags: hash, f#, compiler-construction, abstract-syntax-tree, memoization

How do I cache hash codes for an AST?


I am working on a language in F#, and in testing I find that the runtime spends over 90% of its time comparing for equality, which makes the language so slow as to be unusable. When profiling, the GetHashCode function also shows up fairly high on the list as a source of overhead. The cause is that during method calls I use method bodies (Expr) together with the call arguments as keys in a dictionary, and that triggers repeated traversals of the AST segments.

To improve performance I'd like to add memoization nodes in the AST.

type Expr =
| Add of Expr * Expr
| Lit of int
| HashNode of int * Expr

In the simplified example above, I would like the HashNode to hold the hash of its Expr, so that GetHashCode does not have to traverse any deeper into the AST to calculate it.

That said, I am not sure how I should override the GetHashCode method. Ideally, I'd like to reuse the built-in hash method and somehow make it ignore only the HashNode, but I am not sure how to do that.

More likely, I will have to write my own hash function, but unfortunately I know nothing about hash functions, so I am a bit lost right now.

An alternative idea would be to replace nodes with unique IDs while keeping the hash function as it is, but that would introduce additional complexity into the code that I'd rather avoid unless I have to.
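For the HashNode idea described above, here is a minimal sketch of a hand-written hash, assuming the int in a HashNode is always the hash of the Expr it wraps (the hypothetical `cached` helper guarantees that). As far as I know F# offers no way to make its generated structural hash skip a single case, so the sketch switches the union to `[<CustomEquality; NoComparison>]` and overrides both `GetHashCode` and `Equals`, treating HashNode as transparent. The combining constants 17 and 31 are a common, arbitrary choice, not tuned values.

```fsharp
// Sketch: custom hashing/equality for the question's Expr type so that a
// HashNode answers GetHashCode in O(1) instead of walking its subtree.
[<CustomEquality; NoComparison>]
type Expr =
  | Add of Expr * Expr
  | Lit of int
  | HashNode of int * Expr // assumption: the int is the hash of the wrapped Expr

  override x.GetHashCode() =
    match x with
    | Lit n -> n
    // Recurses into children, but stops at any HashNode (cached hash).
    | Add(a, b) -> (17 * 31 + a.GetHashCode()) * 31 + b.GetHashCode()
    | HashNode(h, _) -> h

  override x.Equals(other) =
    match other with
    | :? Expr as y -> Expr.Eq(x, y)
    | _ -> false

  // Structural equality that sees through HashNode; two HashNodes with
  // different cached hashes are rejected without walking their subtrees.
  static member Eq(a: Expr, b: Expr) : bool =
    if obj.ReferenceEquals(a, b) then true
    else
      match a, b with
      | HashNode(h1, e1), HashNode(h2, e2) -> h1 = h2 && Expr.Eq(e1, e2)
      | HashNode(_, e), _ -> Expr.Eq(e, b)
      | _, HashNode(_, e) -> Expr.Eq(a, e)
      | Lit m, Lit n -> m = n
      | Add(a1, b1), Add(a2, b2) -> Expr.Eq(a1, a2) && Expr.Eq(b1, b2)
      | _ -> false

// Hypothetical helper: compute the hash once and store it in the node.
// Building a HashNode by hand with a wrong hash would break the
// Equals/GetHashCode contract.
let cached (e: Expr) = HashNode(e.GetHashCode(), e)
```

Note that equality still walks both trees when the hashes match; the cached hashes make hashing cheap and let unequal HashNodes fail fast, but they do not make a successful comparison O(1). That is the gap the unique-ID approach closes.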


Solution

  • I needed a similar thing recently in TheGamma (GitHub), where I build a dependency graph (kind of like an AST) that gets recreated very often (whenever you change code in the editor and it gets re-parsed). I have live previews that may take some time to calculate, so I wanted to reuse as much of the previous graph as possible.

    The way I'm doing that is that I attach a "symbol" to each node. Two nodes with the same symbol are equal, which I think you could use for efficient equality testing:

    type Expr =
      | Add of ExprNode * ExprNode
      | Lit of int
    
    and ExprNode(expr:Expr, symbol:int) = 
      member x.Expression = expr
      member x.Symbol = symbol
      override x.GetHashCode() = symbol
      override x.Equals(y) = 
        match y with 
        | :? ExprNode as y -> y.Symbol = x.Symbol
        | _ -> false
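To connect this back to the question's use case, here is a sketch (my own, not from TheGamma) of a memo table keyed by a method body plus its call arguments. `memo`, `callCached`, and the `int` result type are hypothetical; the point is that with the `ExprNode` overrides, hashing and comparing the body part of the key only touches the integer symbol, never the subtree.

```fsharp
open System.Collections.Generic

type Expr =
  | Add of ExprNode * ExprNode
  | Lit of int

and ExprNode(expr: Expr, symbol: int) =
  member x.Expression = expr
  member x.Symbol = symbol
  override x.GetHashCode() = symbol
  override x.Equals(y) =
    match y with
    | :? ExprNode as y -> y.Symbol = x.Symbol
    | _ -> false

// Hypothetical memo table: key = (method body, call arguments).
// Structural hashing of the tuple calls ExprNode.GetHashCode, i.e. the symbol.
let memo = Dictionary<ExprNode * int list, int>(HashIdentity.Structural)

// Return the cached result for this body/arguments pair, or compute and store it.
let callCached (body: ExprNode) (args: int list) (compute: unit -> int) =
  let key = (body, args)
  match memo.TryGetValue key with
  | true, v -> v
  | _ ->
      let v = compute ()
      memo.[key] <- v
      v
```

Two distinct `ExprNode` objects with the same symbol hit the same memo entry, so re-parsing a method body into fresh objects does not invalidate the cache as long as the symbols are reused.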
    

    I keep a cache of nodes: the key is a code for the node kind (0 for Lit, 1 for Add, etc.) followed by the symbols of all nested nodes. For literals, I also add the number itself, which means that creating the same literal twice gives you the same node. So creating a node looks like this:

    let node expr ctx =  
      // Get the key from the kind of the expression
      // and the symbols of all nested nodes in this expression
      let key = 
        match expr with 
        | Lit n -> [0; n]
        | Add(e1, e2) -> [1; e1.Symbol; e2.Symbol]
      // Return either a node from cache or create a new one
      match ListDictionary.tryFind key ctx with
      | Some res -> res
      | None ->
          let res = ExprNode(expr, nextId())
          ListDictionary.set key res ctx
          res
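If the nested ListDictionary feels like overkill, a simpler variant (my substitution, not what the answer uses) is a single `Dictionary` keyed by the whole `int list`: F# lists have structural equality and hashing, and these keys have at most three elements, so hashing one is cheap. A self-contained sketch using the same keying scheme, with `nextId` written using a ref cell:

```fsharp
open System.Collections.Generic

type Expr =
  | Add of ExprNode * ExprNode
  | Lit of int

and ExprNode(expr: Expr, symbol: int) =
  member x.Expression = expr
  member x.Symbol = symbol
  override x.GetHashCode() = symbol
  override x.Equals(y) =
    match y with
    | :? ExprNode as y -> y.Symbol = x.Symbol
    | _ -> false

// Counter for fresh symbols (a ref cell, since closures cannot capture 'let mutable').
let nextId =
  let counter = ref 0
  fun () ->
    counter.Value <- counter.Value + 1
    counter.Value

// Hash-consing with one flat Dictionary keyed by the whole int list.
let node (expr: Expr) (ctx: Dictionary<int list, ExprNode>) =
  let key =
    match expr with
    | Lit n -> [0; n]
    | Add(e1, e2) -> [1; e1.Symbol; e2.Symbol]
  match ctx.TryGetValue key with
  | true, res -> res
  | _ ->
      let res = ExprNode(expr, nextId ())
      ctx.[key] <- res
      res
```

Whether this is faster or slower than the trie depends on the workload; I have not measured it.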
    

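    The ListDictionary module implements a mutable dictionary whose key is a list (of integers, in this case), stored as a trie of nested dictionaries with one level per key element, and nextId is the usual function that generates the next ID: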

    open System.Collections.Generic
    
    type ListDictionaryNode<'K, 'T> = 
      { mutable Result : 'T option
        Nested : Dictionary<'K, ListDictionaryNode<'K, 'T>> }
    
    type ListDictionary<'K, 'V> = Dictionary<'K, ListDictionaryNode<'K, 'V>>
    
    [<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
    module ListDictionary = 
      let tryFind ks dict = 
        let rec loop ks node =
          match ks, node with
          | [], { Result = Some r } -> Some r
          | k::ks, { Nested = d } when d.ContainsKey k -> loop ks (d.[k])
          | _ -> None
        loop ks { Nested = dict; Result = None }
    
      let set ks v dict =
        let rec loop ks (dict:ListDictionary<_, _>) = 
          match ks with
          | [] -> failwith "Empty key not supported"
          | k::ks ->
              if not (dict.ContainsKey k) then 
                dict.[k] <- { Nested = Dictionary<_, _>(); Result = None }
              if List.isEmpty ks then dict.[k].Result <- Some v
              else loop ks (dict.[k].Nested)
        loop ks dict
    
    
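    let nextId = 
      // A ref cell is used because closures cannot capture 'let mutable' locals
      let counter = ref 0
      fun () ->
        counter.Value <- counter.Value + 1
        counter.Value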
    
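A short check of how ListDictionary behaves as a trie: keys that share a prefix share nested dictionaries, and a proper prefix of a stored key is not itself found. The module is repeated (with the `open` it needs) so the snippet runs on its own; the example keys and values are arbitrary.

```fsharp
open System.Collections.Generic

type ListDictionaryNode<'K, 'T> =
  { mutable Result : 'T option
    Nested : Dictionary<'K, ListDictionaryNode<'K, 'T>> }

type ListDictionary<'K, 'V> = Dictionary<'K, ListDictionaryNode<'K, 'V>>

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module ListDictionary =
  // Walk one nested dictionary per key element; succeed only if the final
  // node actually stores a result.
  let tryFind ks dict =
    let rec loop ks node =
      match ks, node with
      | [], { Result = Some r } -> Some r
      | k::ks, { Nested = d } when d.ContainsKey k -> loop ks (d.[k])
      | _ -> None
    loop ks { Nested = dict; Result = None }

  // Create missing levels on the way down and store the value at the last one.
  let set ks v dict =
    let rec loop ks (dict: ListDictionary<_, _>) =
      match ks with
      | [] -> failwith "Empty key not supported"
      | k::ks ->
          if not (dict.ContainsKey k) then
            dict.[k] <- { Nested = Dictionary<_, _>(); Result = None }
          if List.isEmpty ks then dict.[k].Result <- Some v
          else loop ks (dict.[k].Nested)
    loop ks dict

// Example: two keys sharing the prefix [1; 5].
let d = Dictionary<int, ListDictionaryNode<int, string>>()
ListDictionary.set [1; 5; 7] "a" d
ListDictionary.set [1; 5; 8] "b" d
```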

    So, I guess I'm saying that you'll need to implement your own caching mechanism, but this worked quite well for me and may hint at how to do this in your case!