
Type stability for a function involving case distinctions


I am writing a function which computes the weights for the barycentric interpolation formula. Ignoring type stability, that's easy enough:

function baryweights(x)
    n = length(x)
    if n == 1; return [1.0]; end # Not type stable: always returns Vector{Float64}, whatever eltype(x) is

    xmin,xmax = extrema(x)
    x *= 4/(xmax-xmin)
    # ^ Divide by the (logarithmic) capacity (xmax-xmin)/4 of the interval to avoid overflow
    return [
        1/prod(x[i]-x[j] for j in 1:n if j != i)
        for i = 1:n
    ]
end
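For reference, the comprehension computes the standard barycentric weights, evaluated on rescaled nodes. Writing a = min x, b = max x and s = 4/(b - a) (the reciprocal of the logarithmic capacity (b - a)/4 of [a, b]), rescaling every node by s multiplies every weight by the same constant, and that constant cancels in the second (true) barycentric formula. So the function returns the scaled weights \tilde w_i below, which can be used in place of w_i:

```latex
w_i = \frac{1}{\prod_{j \ne i} (x_i - x_j)}, \qquad i = 1, \dots, n

\tilde w_i = \frac{1}{\prod_{j \ne i} (s x_i - s x_j)} = s^{-(n-1)}\, w_i,
\qquad s = \frac{4}{b - a}

p(t) = \frac{\sum_{i=1}^{n} \frac{\tilde w_i}{t - x_i} f_i}{\sum_{i=1}^{n} \frac{\tilde w_i}{t - x_i}}
     = \frac{\sum_{i=1}^{n} \frac{w_i}{t - x_i} f_i}{\sum_{i=1}^{n} \frac{w_i}{t - x_i}}
```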

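The problem for type stability is to work out the element type returned by the n > 1 case, so that the n == 1 case can return an array of the same type. (The n == 1 case needs special handling anyway: there xmin == xmax, so the scale factor divides by zero, and prod over the empty generator throws.) Is there an easy trick to achieve this?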
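One way to see the problem concretely is sketched below, assuming Julia 1.4 or later (for `only`). `baryweights_naive` is just the question's function under a hypothetical new name; `Base.return_types` is Julia's reflection function for inferred return types, and `@inferred` comes from the `Test` standard library. For `Float32` input the inferred return type is not concrete, because the `n == 1` branch returns a `Vector{Float64}` while the general branch returns a `Vector{Float32}`. Reassigning the closure-captured `x` with `x *= ...` can widen inference further, which is what the `let` in the answer addresses.

```julia
using Test  # standard library; provides @inferred

# The question's function, unchanged apart from its (hypothetical) name.
function baryweights_naive(x)
    n = length(x)
    if n == 1; return [1.0]; end  # always Vector{Float64}
    xmin, xmax = extrema(x)
    x *= 4/(xmax-xmin)            # reassigns x, which the comprehension captures
    return [
        1/prod(x[i]-x[j] for j in 1:n if j != i)
        for i = 1:n
    ]
end

# Return type that inference reports for Float32 input (one method, so one entry).
rt = only(Base.return_types(baryweights_naive, (Vector{Float32},)))
println("inferred: ", rt, ", concrete: ", isconcretetype(rt))

# At run time the general branch produces a Vector{Float32} ...
println(typeof(baryweights_naive(Float32[0, 1, 2])))
# ... but the n == 1 branch produces a Vector{Float64}.
println(typeof(baryweights_naive(Float32[5])))

# @inferred throws if the actual return type differs from the inferred one.
try
    @inferred baryweights_naive(Float32[0, 1, 2])
catch err
    println("@inferred failed: ", sprint(showerror, err))
end
```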


Solution

  • Call the function recursively on a dummy argument of length 2 and take the element type of its result:

    function baryweights(x)
        n = length(x)
        if n == 1
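            # Element type produced by the general (n > 1) branch; only the type of the
            # dummy result is used (for float input its values are NaN, as both nodes are 0)
            T = eltype(baryweights(zeros(eltype(x),2)))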
            return [one(T)]
        end
    
        xmin,xmax = extrema(x)
        let x = 4/(xmax-xmin) * x
            # ^ Divide by the (logarithmic) capacity (xmax-xmin)/4 of the interval to
            #   avoid overflow, and wrap in let so the captured x is never reassigned
            #   (reassigning it would box it: https://github.com/JuliaLang/julia/issues/15276)
            return [
                1/prod(x[i]-x[j] for j in 1:n if j != i)
                for i = 1:n
            ]
        end
    end
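A possible alternative, not part of the original answer and only a sketch, avoids the dummy call by pushing `one(S)` and `zero(S)` through the same arithmetic the general branch uses, so the element type `T` is derived from types alone. It assumes `x` is an `AbstractVector` of `Number`s with the usual closed arithmetic (the built-in float, integer and rational types behave this way); `baryweights_typed` is a hypothetical name. The `@inferred` calls at the end check that both branches infer to the same concrete type.

```julia
using Test  # standard library; provides @inferred

function baryweights_typed(x::AbstractVector{S}) where {S<:Number}
    n = length(x)
    # Mirror the general branch on representative values: scale factor 4/(b - a),
    # scaled nodes, their difference, and the reciprocal. Only the type is kept.
    c = 4 / (one(S) - zero(S))
    T = typeof(1 / (c * one(S) - c * zero(S)))
    n == 1 && return [one(T)]

    xmin, xmax = extrema(x)
    # `let` introduces a fresh binding, so the captured x is never reassigned.
    let x = 4 / (xmax - xmin) * x
        return [1 / prod(x[i] - x[j] for j in 1:n if j != i) for i in 1:n]
    end
end

w  = @inferred baryweights_typed([0.0, 1.0, 2.0])   # general branch
w1 = @inferred baryweights_typed(Float32[5])        # n == 1 branch
println(w, " ", typeof(w1))
```

Compared with the recursive call, this does no extra work at run time, but it duplicates knowledge of the arithmetic in two places, so the type computation can drift out of sync if the formula changes later.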