
Reducing JIT time with recursive calls in Julia


I have a recursive function that operates on a binary tree of integers, implemented as nested pairs of pairs or Ints. My function creates a new tree with a different structure and calls itself recursively until some condition is met. The issue is that the first time the code runs, it takes a very long time to JIT-compile all the possible signatures of the function; afterwards it runs fine.

Here is a minimal working example:

    my_tree = ((((6 => 7) => (6 => 7)) => ((7 => 7) => (0 => 7))) => (((8 => 7) => (7 => 7)) => ((8 => 8) => (8 => 0)))) => ((((2 => 4) => 7) => (6 => (0 => 5))) => (((6 => 8) => (2 => 8)) => ((2 => 1) => (4 => 5))))

    function tree_reduce(tree::Pair)
        left, right = tree
        left isa Pair && (left = tree_reduce(left))
        right isa Pair && (right = tree_reduce(right))
        return left + right
    end

    @show my_tree
    @show tree_reduce(my_tree)

    using MethodAnalysis
    mis = methodinstances(tree_reduce)  # named `mis` to avoid shadowing Base.methods
    @show length(mis)

Although this example is not noticeably slow, it still generates 9 method instances, among them:

tree_reduce(::Pair{Pair{Pair{Int64, Int64}, Pair{Int64, Int64}}, Pair{Pair{Int64, Int64}, Pair{Int64, Int64}}})
tree_reduce(::Pair{Pair{Int64, Int64}, Pair{Int64, Int64}})
tree_reduce(::Pair{Int64, Int64})
tree_reduce(::Pair{Pair{Pair{Pair{Int64, Int64}, Int64}, Pair{Int64, Pair{Int64, Int64}}}, Pair{Pair{Pair{Int64, Int64}, Pair{Int64, Int64}}, Pair{Pair{Int64, Int64}, Pair{Int64, Int64}}}})
etc ...
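To see where these instances come from without MethodAnalysis, one can collect the concrete type of every Pair node in the tree: each distinct type is a separate signature that tree_reduce is compiled for. The helper names subtree_types and subtree_types! below are hypothetical, written only for this illustration; the helper itself uses Base's @nospecialize so that it is not compiled once per nested type.

```julia
# Hypothetical helper (not from the question): record the concrete type of
# every Pair node in a nested-Pair tree. Each distinct type is one signature
# that tree_reduce gets specialized (compiled) for.
# @nospecialize stops this helper from being compiled once per nested type.
function subtree_types!(seen::Set{DataType}, @nospecialize(tree))
    tree isa Pair || return seen   # leaves (plain Ints) never reach tree_reduce
    push!(seen, typeof(tree))
    subtree_types!(seen, tree.first)
    subtree_types!(seen, tree.second)
    return seen
end

subtree_types(@nospecialize(tree)) = subtree_types!(Set{DataType}(), tree)

my_tree = ((((6 => 7) => (6 => 7)) => ((7 => 7) => (0 => 7))) => (((8 => 7) => (7 => 7)) => ((8 => 8) => (8 => 0)))) => ((((2 => 4) => 7) => (6 => (0 => 5))) => (((6 => 8) => (2 => 8)) => ((2 => 1) => (4 => 5))))

types = subtree_types(my_tree)
println(length(types))   # prints 9: one type per method instance of tree_reduce
foreach(println, types)  # the distinct Pair types, e.g. Pair{Int64, Int64}
```

Every new tree shape produced by the real (non-minimal) function adds more such types, which is why the first call pays for so much compilation.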

Is there a way of avoiding this, precompiling, speeding it up, writing a generic function, or running a particular function (or part of one) in an interpreted mode? I would be prepared to make the overall performance of the code slightly worse in exchange for having it run faster on the first top-level call to tree_reduce.


Solution

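  • @nospecialize is an option, but I think that in this case a separate data structure, one that doesn't encode the shape of the whole tree in its type, is in order. Pair is really intended for strongly typed pairs of things, not for large nested structures: every distinct nesting of Pairs is a distinct type and triggers its own compilation, whereas with the types below every node of an integer tree is just a Branch{Int64} or a Leaf{Int64}.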

    julia> abstract type BiTree{T} end
    julia> struct Branch{T} <: BiTree{T} 
               left::BiTree{T}
               right::BiTree{T}
           end
    
    julia> struct Leaf{T} <: BiTree{T}
               value::T
           end
    
    julia> Base.foldl(f, b::Branch) = f(foldl(f, b.left), foldl(f, b.right))
    
    julia> Base.foldl(f, l::Leaf) = f(l.value)
    
    julia> (→)(l::Branch, r::Branch) = Branch(l, r) # just for illustration
    → (generic function with 1 method)
    
    julia> (→)(l, r::Branch) = Branch(Leaf(l), r)
    → (generic function with 2 methods)
    
    julia> (→)(l::Branch, r) = Branch(l, Leaf(r))
    → (generic function with 3 methods)
    
    julia> (→)(l, r) = Branch(Leaf(l), Leaf(r))
    → (generic function with 4 methods)
    
    julia> my_tree = ((((6 → 7) → (6 → 7)) → ((7 → 7) → (0 → 7))) → (((8 → 7) → (7 → 7)) → ((8 → 8) → (8 → 0)))) → ((((2 → 4) → 7) → (6 → (0 → 5))) → (((6 → 8) → (2 → 8)) → ((2 → 1) → (4 → 5))));
    
    julia> typeof(my_tree)
    Branch{Int64}
    
    julia> foldl(+, my_tree)
    160
    

    This also has the advantage that you can overload other methods, such as those for printing or indexing, without danger of breaking anything, since BiTree is your own type rather than Base's Pair, which other code relies on.
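For comparison, the @nospecialize option mentioned at the start of the answer could look roughly like the sketch below, which keeps the original Pair representation. This is a minimal sketch under that assumption, not the answerer's code, and the name tree_reduce_nospec is hypothetical. Note that @nospecialize is only a compiler hint: it limits how many native-code specializations are generated, but type inference may still analyze concrete argument types, so it may reduce first-call latency rather than eliminate it.

```julia
# Hypothetical variant of the question's tree_reduce (not from the answer).
# @nospecialize asks the compiler not to generate a separate native-code
# specialization for every concrete Pair type the function receives.
function tree_reduce_nospec(@nospecialize(tree::Pair))
    left, right = tree.first, tree.second
    # The compiled code only knows `tree` is some Pair, so operations such as
    # `+` on its contents fall back to dynamic dispatch at run time:
    # slower per call, but far less compilation up front.
    left isa Pair && (left = tree_reduce_nospec(left))
    right isa Pair && (right = tree_reduce_nospec(right))
    return left + right
end

my_tree = ((((6 => 7) => (6 => 7)) => ((7 => 7) => (0 => 7))) => (((8 => 7) => (7 => 7)) => ((8 => 8) => (8 => 0)))) => ((((2 => 4) => 7) => (6 => (0 => 5))) => (((6 => 8) => (2 => 8)) => ((2 => 1) => (4 => 5))))

println(tree_reduce_nospec(my_tree))  # 160, the same sum foldl(+, my_tree) gives above
```

Which approach has the lower first-call cost for the real workload is worth measuring, for example by running the first call under @time in a fresh Julia session.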