Tags: types, julia, generic-programming

Type-stability issue when a function adds a dimension to its input argument


I have a function which returns an array with the same element type as the input array, but with one additional dimension. Here is a simple example:

function myfun(a::Array{T,N}) where {T,N}
   b = Array{T,N+1}(size(a)...,2)
   b[:] = 42
   return b
end

When this function is called on a 2×2 array, it returns a 2×2×2 array.

myfun(zeros(2,2))
2×2×2 Array{Float64,3}:
[:, :, 1] =
 42.0  42.0
 42.0  42.0

[:, :, 2] =
 42.0  42.0
 42.0  42.0

However, this function is not type-stable. According to @code_warntype, b is of type Any.

Even with a type annotation on b, the result is not type-stable with respect to the number of dimensions:

function myfun(a::Array{T,N}) where {T,N}
      b = Array{T,N+1}(size(a)...,2) :: Array{T,N+1}
      b[:] = T(42)
      return b
end

@code_warntype myfun(zeros(2,2)) now reports Array{Float64,_} where _ as the type of b. Shouldn't Julia be able to figure out that the number of dimensions is 3 when the input argument has 2 dimensions?

I am using Julia 0.6.2 (on Linux).


Solution

  • That's because the constructor call Array{T,N+1}(size(a)...,2), including the N+1 in its type parameter, is evaluated at run time, so the compiler cannot infer the number of dimensions of b. You can use a @generated function to compute N+1 at compile time:

    julia> @generated function myfun(a::Array{T,N}) where {T,N}
               NN = N+1
               quote 
                   b = Array{$T,$NN}(size(a)...,2)
                   b[:] = 42
                   return b
               end
           end
    myfun (generic function with 1 method)
    
    julia> @code_warntype myfun(zeros(2,2))
    Variables:
      #self# <optimized out>
      a::Array{Float64,2}
      b::Array{Float64,3}
    
    Body:
      begin  # line 2:
          # meta: location REPL[1] # line 4:
          SSAValue(2) = (Base.arraysize)(a::Array{Float64,2}, 1)::Int64
          SSAValue(1) = (Base.arraysize)(a::Array{Float64,2}, 2)::Int64
          b::Array{Float64,3} = $(Expr(:foreigncall, :(:jl_alloc_array_3d), Array{Float64,3}, svec(Any, Int64, Int64, Int64), Array{Float64,3}, 0, SSAValue(2), 0, SSAValue(1), 0, :($(QuoteNode(2))), 0)) # line 5:
          $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,3}, ::Int64), :(Base.fill!), :(b), 42))
          # meta: pop location
          return b::Array{Float64,3}
      end::Array{Float64,3}
    
    julia> myfun(zeros(2,2))
    2×2×2 Array{Float64,3}:
    [:, :, 1] =
     42.0  42.0
     42.0  42.0
    
    [:, :, 2] =
     42.0  42.0
     42.0  42.0
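
    For readers on a newer Julia, a non-generated variant may be enough. The following is a minimal sketch, assuming Julia 1.0 or later (the question targets 0.6.2, where the undef argument does not exist). It passes the dimensions as a tuple, (size(a)..., 2), whose length N+1 is part of the tuple type, so the compiler should be able to infer the dimensionality of the result without a generated function. The name myfun_tuple is hypothetical and not part of the original answer.

```julia
using Test  # standard library; provides @inferred

# Hypothetical variant of myfun for Julia 1.x (an assumption, not the original answer's code).
function myfun_tuple(a::Array{T,N}) where {T,N}
    dims = (size(a)..., 2)       # a tuple of N+1 Ints; its length is encoded in its type
    b = Array{T}(undef, dims)    # the number of dimensions is taken from the tuple length
    fill!(b, 42)                 # 42 is converted to the element type T
    return b
end

# @inferred throws an error if the return type cannot be inferred concretely.
b = @inferred myfun_tuple(zeros(2, 2))
```

    Here @inferred plays the role that reading the @code_warntype output plays above: it fails loudly if the return type is not concrete, which makes it convenient in a test suite.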