
Why does moving a for loop into a list comprehension make a function type-unstable?


Look at these functions:

function fA1(s::AbstractString, n)
    T = getfield(Base, Symbol(s)) 
    x = one(T)
    for i in 1:n
        x = x+x
    end
    return x
end
function fA2(s::AbstractString, n)
    T = getfield(Base, Symbol(s)) 
    x = one(Float64)
    for i in 1:n
        x = x+x
    end
    return x
end
function fB1(s::AbstractString, n)
    T = getfield(Base, Symbol(s)) 
    x = one(T)
    [x = x+x for i in 1:n]
    return x
end
function fB2(s::AbstractString, n)
    T = getfield(Base, Symbol(s)) 
    x = one(Float64)
    [x = x+x for i in 1:n]
    return x
end

fA1 is type-unstable and slow, while fA2 is type-stable and fast. However, when I move the for loop into a list comprehension, both fB1 and fB2 become type-unstable and slow, even though the numerical result (obviously) stays the same.

Why is that?


Solution

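  • The reason is that a comprehension creates a new local scope, as explained in the Julia manual.

    More precisely, the body of the comprehension is compiled into an anonymous function (a closure). Because that closure reassigns `x`, which belongs to the enclosing function, the compiler has to store `x` in a heap-allocated box (`Core.Box`), and the type of the box contents cannot be inferred. This is why fB2 is type-unstable even though, unlike fB1, it never uses the `T` computed from `s`.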
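    To make the closure visible, here is a hand-written sketch of roughly what the comprehension in fB2 turns into: a `Base.Generator` wrapping an anonymous function, passed to `collect`. This is an approximation of the lowering, not the compiler's exact output, and `fB2_desugared` is a hypothetical name used only for illustration.

    ```julia
    # Approximate desugaring of `[x = x+x for i in 1:n]` (illustrative sketch).
    # `fB2_desugared` is a hypothetical name, not part of the original question.
    function fB2_desugared(n)
        x = one(Float64)
        # The comprehension body becomes a closure `i -> (x = x + x)`.
        # Since the closure reassigns the captured `x`, the compiler stores
        # `x` in a `Core.Box`, whose contents have no inferable type.
        collect(Base.Generator(i -> (x = x + x), 1:n))
        return x
    end

    fB2_desugared(3)  # returns 8.0 (1.0 doubled three times)
    ```

    The result is the same as with a plain loop; only the way `x` is stored changes.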

    In this case, if you insist on updating a variable from the outer scope inside a comprehension (which is generally not recommended, because such code tends to confuse readers), the best option I know of is a type annotation. Maybe someone can come up with a better solution, but I think this is unlikely given the current state of the Julia compiler:

    function fB2(n)
        x::Float64 = one(Float64)
        [x = x+x for i in 1:n]
        return x
    end
    

    This does not avoid boxing, but it makes the return type inferable, and performance should improve significantly.
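    One way to check the claim about inference is to compare the inferred return types of the two variants with `Base.return_types`. The sketch below assumes the function names `f_unannotated` and `f_annotated`, which are hypothetical; their bodies match `f_slow` and `f_fast` in the benchmark.

    ```julia
    # Same body as f_slow: x is boxed and carries no declared type.
    function f_unannotated(n)
        x = one(Float64)
        [x = x + x for i in 1:n]
        return x
    end

    # Same body as f_fast: x is still boxed, but the declaration makes
    # every read of it assert Float64, so the return type is known.
    function f_annotated(n)
        x::Float64 = one(Float64)
        [x = x + x for i in 1:n]
        return x
    end

    # Inspect what the compiler infers for an Int argument.
    println(Base.return_types(f_unannotated, (Int,)))
    println(Base.return_types(f_annotated, (Int,)))
    ```

    In the REPL, `@code_warntype f_unannotated(10)` shows the same information interactively, with the boxed variable highlighted.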

    In the future, it is very likely that the Julia compiler will be smart enough to handle such code without requiring a type annotation.

    Here is a performance comparison:

    julia> using BenchmarkTools
    
    julia> function f_fast(n)
               x::Float64 = one(Float64)
               [x = x+x for i in 1:n]
               return x
           end
    f_fast (generic function with 1 method)
    
    julia> function f_slow(n)
               x = one(Float64)
               [x = x+x for i in 1:n]
               return x
           end
    f_slow (generic function with 1 method)
    
    julia> @benchmark f_fast(1000)
    BenchmarkTools.Trial:
      memory estimate:  23.63 KiB
      allocs estimate:  1004
      --------------
      minimum time:     4.357 μs (0.00% GC)
      median time:      7.257 μs (0.00% GC)
      mean time:        10.314 μs (16.54% GC)
      maximum time:     5.256 ms (99.86% GC)
      --------------
      samples:          10000
      evals/sample:     7
    
    julia> @benchmark f_slow(1000)
    BenchmarkTools.Trial:
      memory estimate:  23.66 KiB
      allocs estimate:  1005
      --------------
      minimum time:     17.899 μs (0.00% GC)
      median time:      26.300 μs (0.00% GC)
      mean time:        34.916 μs (15.56% GC)
      maximum time:     36.220 ms (99.91% GC)
      --------------
      samples:          10000
      evals/sample:     1