Search code examples
Tags: debugging, iterator, julia, type-stability

How to avoid memory allocations in custom Julia iterators?


Consider the following Julia "compound" iterator: it merges two iterators, `a` and `b`, each of which is assumed to be sorted according to `order`, into a single ordered sequence:

"""
    MergeSorted(a, b, order=Base.Order.Forward)

Iterator over the merge of `a` and `b`, each of which is assumed to be sorted
according to `order`. The element type parameter `T` is
`promote_type(eltype(a), eltype(b))`.
"""
struct MergeSorted{T,A,B,O}
    a::A
    b::B
    order::O

    function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
        E = promote_type(eltype(A), eltype(B))
        return new{E,A,B,O}(a, b, order)
    end
end

# Element type of the merged sequence: the promoted element type `T`.
Base.eltype(::Type{MergeSorted{T,A,B,O}}) where {T,A,B,O} = T
# The merged length is not tracked, so declare it unknown. Otherwise the default
# `Base.HasLength()` makes `collect` and friends call the undefined `length`.
Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()

# Yield the next element of the merged sequence.
# `state` is the pair `(a_result, b_result)` of pending `iterate` results for
# `self.a` and `self.b`; each becomes `nothing` once that input is exhausted.
# An element of `a` is taken only when it is strictly less than the current
# element of `b` under `self.order`, so on ties `b`'s element comes first.
@inline function Base.iterate(self::MergeSorted{T},
                              state=(iterate(self.a), iterate(self.b))) where T
    a_result, b_result = state
    if b_result === nothing
        # `b` is exhausted: drain `a`.
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        # `convert` rather than `T(a_curr)`: abstract element types such as `Any`
        # (e.g. from `promote_type(Int, Any)`) have no constructor, and
        # constructors may copy values that already have type `T`.
        return convert(T, a_curr), (iterate(self.a, a_state), b_result)
    end

    b_curr, b_state = b_result
    if a_result !== nothing
        a_curr, a_state = a_result
        Base.Order.lt(self.order, a_curr, b_curr) &&
            return convert(T, a_curr), (iterate(self.a, a_state), b_result)
    end
    return convert(T, b_curr), (a_result, iterate(self.b, b_state))
end

This code works, but it is type-unstable, because Julia's iteration protocol is inherently so: `iterate` returns either `nothing` or a tuple. In most cases the compiler can resolve this automatically; here, however, it cannot, as the following test shows that temporaries are allocated:

>>> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
>>> sum(x);
>>> @time sum(x);
0.000013 seconds (61 allocations: 2.312 KiB)

Note the allocation count.

Is there any way to efficiently debug such situations other than experimenting with the code and hoping that the compiler will be able to optimize away the type ambiguities? And does anyone know of a solution in this specific case that does not create temporaries?


Solution

  • How to diagnose the problem?

    Answer: use @code_warntype

    Run:

    julia> @code_warntype iterate(x, iterate(x)[2])
    Variables
      #self#::Core.Const(iterate)
      self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering}
      state::Tuple{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
      @_4::Int64
      @_5::Int64
      @_6::Union{}
      @_7::Int64
      b_state::Int64
      b_curr::Int64
      a_state::Int64
      a_curr::Int64
      b_result::Tuple{Int64, Int64}
      a_result::Tuple{Int64, Int64}
    
    Body::Tuple{Int64, Any}
    1 ─       nothing
    │         Core.NewvarNode(:(@_4))
    │         Core.NewvarNode(:(@_5))
    │         Core.NewvarNode(:(@_6))
    │         Core.NewvarNode(:(b_state))
    │         Core.NewvarNode(:(b_curr))
    │         Core.NewvarNode(:(a_state))
    │         Core.NewvarNode(:(a_curr))
    │   %9  = Base.indexed_iterate(state, 1)::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(2)])
    │         (a_result = Core.getfield(%9, 1))
    │         (@_7 = Core.getfield(%9, 2))
    │   %12 = Base.indexed_iterate(state, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Tuple{Int64, Int64}, Int64}, Any[Tuple{Int64, Int64}, Core.Const(3)])
    │         (b_result = Core.getfield(%12, 1))
    │   %14 = (b_result === Main.nothing)::Core.Const(false)
    └──       goto #3 if not %14
    2 ─       Core.Const(:(a_result === Main.nothing))
    │         Core.Const(:(%16))
    │         Core.Const(:(return Main.nothing))
    │         Core.Const(:(Base.indexed_iterate(a_result, 1)))
    │         Core.Const(:(a_curr = Core.getfield(%19, 1)))
    │         Core.Const(:(@_6 = Core.getfield(%19, 2)))
    │         Core.Const(:(Base.indexed_iterate(a_result, 2, @_6)))
    │         Core.Const(:(a_state = Core.getfield(%22, 1)))
    │         Core.Const(:(($(Expr(:static_parameter, 1)))(a_curr)))
    │         Core.Const(:(Base.getproperty(self, :a)))
    │         Core.Const(:(Main.iterate(%25, a_state)))
    │         Core.Const(:(Core.tuple(%26, b_result)))
    │         Core.Const(:(Core.tuple(%24, %27)))
    └──       Core.Const(:(return %28))
    3 ┄ %30 = Base.indexed_iterate(b_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
    │         (b_curr = Core.getfield(%30, 1))
    │         (@_5 = Core.getfield(%30, 2))
    │   %33 = Base.indexed_iterate(b_result, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
    │         (b_state = Core.getfield(%33, 1))
    │   %35 = (a_result !== Main.nothing)::Core.Const(true)
    └──       goto #6 if not %35
    4 ─ %37 = Base.indexed_iterate(a_result, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
    │         (a_curr = Core.getfield(%37, 1))
    │         (@_4 = Core.getfield(%37, 2))
    │   %40 = Base.indexed_iterate(a_result, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
    │         (a_state = Core.getfield(%40, 1))
    │   %42 = Base.Order::Core.Const(Base.Order)
    │   %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
    │   %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
    │   %45 = a_curr::Int64
    │   %46 = (%43)(%44, %45, b_curr)::Bool
    └──       goto #6 if not %46
    5 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
    │   %49 = Base.getproperty(self, :a)::Vector{Int64}
    │   %50 = Main.iterate(%49, a_state)::Union{Nothing, Tuple{Int64, Int64}}
    │   %51 = Core.tuple(%50, b_result)::Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}
    │   %52 = Core.tuple(%48, %51)::Tuple{Int64, Tuple{Union{Nothing, Tuple{Int64, Int64}}, Tuple{Int64, Int64}}}
    └──       return %52
    6 ┄ %54 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
    │   %55 = a_result::Tuple{Int64, Int64}
    │   %56 = Base.getproperty(self, :b)::Vector{Int64}
    │   %57 = Main.iterate(%56, b_state)::Union{Nothing, Tuple{Int64, Int64}}
    │   %58 = Core.tuple(%55, %57)::Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}
    │   %59 = Core.tuple(%54, %58)::Tuple{Int64, Tuple{Tuple{Int64, Int64}, Union{Nothing, Tuple{Int64, Int64}}}}
    └──       return %59
    

    and you can see that there are too many possible return types, so Julia's inference gives up and widens the second element of the return type to `Any` (see `Body::Tuple{Int64, Any}` in the listing above).

    How to fix the problem?

    Answer: reduce the number of possible return types of `iterate`.

    Here is a quick write-up. I do not claim it is the tersest solution, and I have not tested it extensively, so there may be bugs, but it was simple enough to write quickly from your code to show how one could approach the problem. Note that I use special methods for when one of the collections is empty, since in that case it is faster to just iterate the other collection:

    # Type used to store `iterate` results of an iterator of type `I`, given its
    # first result `r` (`Nothing` if the iterator is empty). When `I` declares an
    # element type, that type is used for the element slot instead of the runtime
    # type of the first element, so later elements of a different concrete type
    # still fit (e.g. `Any[1, 2.5]`: slot type `Any`, not `Int`).
    function _iterate_result_type(::Type{I}, r) where {I}
        r === nothing && return Nothing
        elt = Base.IteratorEltype(I) isa Base.HasEltype ? eltype(I) : typeof(r[1])
        return Tuple{elt, typeof(r[2])}
    end

    struct MergeSorted{T,A,B,O,F1,F2}
        a::A
        b::B
        order::O
        fa::F1   # first `iterate(a)` result, computed in the constructor
        fb::F2   # first `iterate(b)` result, computed in the constructor
        function MergeSorted(a::A, b::B, order::O=Base.Order.Forward) where {A,B,O}
            fa, fb = iterate(a), iterate(b)
            # F1/F2 become `Nothing` for an empty input; the `iterate` methods
            # dispatch on that to handle empty inputs separately.
            F1 = _iterate_result_type(A, fa)
            F2 = _iterate_result_type(B, fb)
            new{promote_type(eltype(A),eltype(B)),A,B,O,F1,F2}(a, b, order, fa, fb)
        end
    end
    
    # Element type of the merged sequence: the promoted element type `T`.
    # `MergeSorted` has six type parameters, so match on `T` alone via `<:`;
    # this covers every concrete `MergeSorted{T,A,B,O,F1,F2}`.
    Base.eltype(::Type{<:MergeSorted{T}}) where {T} = T
    # The merged length is not tracked, so declare it unknown. Otherwise the
    # default `Base.HasLength()` makes `collect` call the undefined `length`.
    Base.IteratorSize(::Type{<:MergeSorted}) = Base.SizeUnknown()
    
    # Iteration state: the pending `iterate` result for each input, or `nothing`
    # once that input is exhausted. The type parameters stay fixed across
    # iterations, so every step returns the same concrete state type.
    struct State{Ta, Tb}
        a::Union{Ta, Nothing}
        b::Union{Tb, Nothing}
    end
    
    # Both inputs were empty at construction: the merged sequence is empty.
    Base.iterate(self::MergeSorted{T,A,B,O,Nothing,Nothing}) where {T,A,B,O} = nothing
    
    # Only `b` is empty: yield `a`'s first element (cached at construction),
    # converted to the merged element type `T` so values match `eltype`.
    function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}) where {T,A,B,O,F1}
        a_curr, a_state = self.fa
        return convert(T, a_curr), a_state
    end
    
    # Only `b` is empty: continue iterating `a` with its own state, converting
    # each element to the merged element type `T` so values match `eltype`.
    function Base.iterate(self::MergeSorted{T,A,B,O,F1,Nothing}, state) where {T,A,B,O,F1}
        a_result = iterate(self.a, state)
        a_result === nothing && return nothing
        a_curr, a_state = a_result
        return convert(T, a_curr), a_state
    end
    
    # Only `a` is empty: yield `b`'s first element (cached at construction),
    # converted to the merged element type `T` so values match `eltype`.
    function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}) where {T,A,B,O,F2}
        b_curr, b_state = self.fb
        return convert(T, b_curr), b_state
    end
    
    # Only `a` is empty: continue iterating `b` with its own state, converting
    # each element to the merged element type `T` so values match `eltype`.
    function Base.iterate(self::MergeSorted{T,A,B,O,Nothing,F2}, state) where {T,A,B,O,F2}
        b_result = iterate(self.b, state)
        b_result === nothing && return nothing
        b_curr, b_state = b_result
        return convert(T, b_curr), b_state
    end
    
    # General first step: wrap the first results cached at construction in a
    # `State` and delegate to the two-argument method. Empty inputs (`F1` or
    # `F2` equal to `Nothing`) dispatch to more specific methods instead.
    @inline Base.iterate(self::MergeSorted{T,A,B,O,F1,F2}) where {T,A,B,O,F1,F2} =
        iterate(self, State{F1,F2}(self.fa, self.fb))
    
    # Merge step for the general case. `state` holds the pending `iterate`
    # results of `self.a` and `self.b` (`nothing` once exhausted). Every return
    # path builds the same concrete `State{F1,F2}`, which keeps this method's
    # return type inferable.
    @inline function Base.iterate(self::MergeSorted{T,A,B,O,F1,F2},
        state::State{F1,F2}) where {T,A,B,O,F1,F2}
        a_result, b_result = state.a, state.b

        if b_result === nothing
            # `b` is exhausted: drain `a`.
            a_result === nothing && return nothing
            a_curr, a_state = a_result
            # `convert` rather than `T(a_curr)`: abstract element types such as
            # `Any` have no constructor, and constructors may copy values that
            # already have type `T`.
            return convert(T, a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
        end

        b_curr, b_state = b_result
        if a_result !== nothing
            a_curr, a_state = a_result
            # Take from `a` only when strictly smaller; ties go to `b`.
            Base.Order.lt(self.order, a_curr, b_curr) &&
                return convert(T, a_curr), State{F1,F2}(iterate(self.a, a_state), b_result)
        end
        return convert(T, b_curr), State{F1,F2}(a_result, iterate(self.b, b_state))
    end
    

    And now you have:

    julia> x = MergeSorted([1,4,5,9,32,44], [0,7,9,24,134]);
    
    julia> sum(x)
    269
    
    julia> @allocated sum(x)
    0
    
    julia> @code_warntype iterate(x, iterate(x)[2])
    Variables
      #self#::Core.Const(iterate)
      self::MergeSorted{Int64, Vector{Int64}, Vector{Int64}, Base.Order.ForwardOrdering, Tuple{Int64, Int64}, Tuple{Int64, Int64}}
      state::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
      @_4::Int64
      @_5::Int64
      @_6::Int64
      b_state::Int64
      b_curr::Int64
      a_state::Int64
      a_curr::Int64
      b_result::Union{Nothing, Tuple{Int64, Int64}}
      a_result::Union{Nothing, Tuple{Int64, Int64}}
    
    Body::Union{Nothing, Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}}
    1 ─       nothing
    │         Core.NewvarNode(:(@_4))
    │         Core.NewvarNode(:(@_5))
    │         Core.NewvarNode(:(@_6))
    │         Core.NewvarNode(:(b_state))
    │         Core.NewvarNode(:(b_curr))
    │         Core.NewvarNode(:(a_state))
    │         Core.NewvarNode(:(a_curr))
    │   %9  = Base.getproperty(state, :a)::Union{Nothing, Tuple{Int64, Int64}}
    │   %10 = Base.getproperty(state, :b)::Union{Nothing, Tuple{Int64, Int64}}
    │         (a_result = %9)
    │         (b_result = %10)
    │   %13 = (b_result === Main.nothing)::Bool
    └──       goto #5 if not %13
    2 ─ %15 = (a_result === Main.nothing)::Bool
    └──       goto #4 if not %15
    3 ─       return Main.nothing
    4 ─ %18 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
    │         (a_curr = Core.getfield(%18, 1))
    │         (@_6 = Core.getfield(%18, 2))
    │   %21 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_6::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
    │         (a_state = Core.getfield(%21, 1))
    │   %23 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
    │   %24 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
    │   %25 = Base.getproperty(self, :a)::Vector{Int64}
    │   %26 = Main.iterate(%25, a_state)::Union{Nothing, Tuple{Int64, Int64}}
    │   %27 = (%24)(%26, b_result::Core.Const(nothing))::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
    │   %28 = Core.tuple(%23, %27)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
    └──       return %28
    5 ─ %30 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
    │         (b_curr = Core.getfield(%30, 1))
    │         (@_5 = Core.getfield(%30, 2))
    │   %33 = Base.indexed_iterate(b_result::Tuple{Int64, Int64}, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
    │         (b_state = Core.getfield(%33, 1))
    │   %35 = (a_result !== Main.nothing)::Bool
    └──       goto #8 if not %35
    6 ─ %37 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
    │         (a_curr = Core.getfield(%37, 1))
    │         (@_4 = Core.getfield(%37, 2))
    │   %40 = Base.indexed_iterate(a_result::Tuple{Int64, Int64}, 2, @_4::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
    │         (a_state = Core.getfield(%40, 1))
    │   %42 = Base.Order::Core.Const(Base.Order)
    │   %43 = Base.getproperty(%42, :lt)::Core.Const(Base.Order.lt)
    │   %44 = Base.getproperty(self, :order)::Core.Const(Base.Order.ForwardOrdering())
    │   %45 = a_curr::Int64
    │   %46 = (%43)(%44, %45, b_curr)::Bool
    └──       goto #8 if not %46
    7 ─ %48 = ($(Expr(:static_parameter, 1)))(a_curr)::Int64
    │   %49 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
    │   %50 = Base.getproperty(self, :a)::Vector{Int64}
    │   %51 = Main.iterate(%50, a_state)::Union{Nothing, Tuple{Int64, Int64}}
    │   %52 = (%49)(%51, b_result::Tuple{Int64, Int64})::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
    │   %53 = Core.tuple(%48, %52)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
    └──       return %53
    8 ┄ %55 = ($(Expr(:static_parameter, 1)))(b_curr)::Int64
    │   %56 = Core.apply_type(Main.State, $(Expr(:static_parameter, 5)), $(Expr(:static_parameter, 6)))::Core.Const(State{Tuple{Int64, Int64}, Tuple{Int64, Int64}})
    │   %57 = a_result::Union{Nothing, Tuple{Int64, Int64}}
    │   %58 = Base.getproperty(self, :b)::Vector{Int64}
    │   %59 = Main.iterate(%58, b_state)::Union{Nothing, Tuple{Int64, Int64}}
    │   %60 = (%56)(%57, %59)::State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}
    │   %61 = Core.tuple(%55, %60)::Tuple{Int64, State{Tuple{Int64, Int64}, Tuple{Int64, Int64}}}
    └──       return %61
    

    EDIT: I have now realized that my implementation is not fully correct: it assumes that the non-`nothing` return value of `iterate` always has the same type, which need not be the case. If it does not, however, the compiler has to allocate anyway. So a fully correct solution would first check whether `iterate` is type-stable: if it is, use my solution; if it is not, use e.g. your solution.