
Type stable functions in Julia for general distribution when working with floating point inputs


In Julia, I have a function like this:

function f(x::Float64, c::Float64)
    if x <= 0
        return 0.0
    elseif x <= c
        return x / c
    else
        return 1.0
    end
end

The function is type-stable and so will run quickly. However, I want to include the function in a package for general distribution, including to 32-bit machines. What is best practice here? Should I write another version of the function for Float32? That could get annoying if I have many such functions.

Could I use FloatingPoint for the input types instead? If I do this, how do I ensure the function remains type-stable, since the literals 0.0 and 1.0 are Float64, while x / c could be Float32 or Float16? Maybe I could use a type parameter, e.g. T<:FloatingPoint with x::T and c::T, and then use 0.0::T and 1.0::T in the body of the function to ensure it is type-stable?

Any guidance here would be appreciated.


Solution

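  • The one and zero functions are useful here. zero(x) and one(x) return 0 and 1 in the type of x, so every branch returns the same type as long as x and c share a floating-point type: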

    function f(x, c)
        if x <= 0
            return zero(x)
        elseif x <= c
            return x/c
        else
            return one(x)
        end
    end
    

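    This version is stricter about the input types: x and c must share the same floating-point type T. Current Julia spells the abstract type AbstractFloat (FloatingPoint was its name before Julia 0.4) and writes the type parameter with a where clause:

    function f(x::T, c::T) where {T<:AbstractFloat}
        if x <= 0
            return zero(T)
        elseif x <= c
            return x / c
        else
            return one(T)
        end
    end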
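
To check these claims rather than take them on trust, the sketch below compares a relaxed signature that keeps the Float64 literals with the parametric version, using Base.return_types to inspect the inferred return type. The names f_literal, f_generic, f_promote and rt are hypothetical, chosen only for this illustration. f_promote is an optional wrapper that is not part of the answer above. It assumes you want mixed argument types to be promoted to a common type rather than rejected.

```julia
# Hypothetical names (f_literal, f_generic, f_promote, rt) for illustration only.

# Relaxed signature, but the literals stay Float64: the return type now
# depends on which branch is taken when the inputs are not Float64.
function f_literal(x::AbstractFloat, c::AbstractFloat)
    if x <= 0
        return 0.0
    elseif x <= c
        return x / c
    else
        return 1.0
    end
end

# Parametric version in current Julia syntax: every branch returns a T.
function f_generic(x::T, c::T) where {T<:AbstractFloat}
    if x <= 0
        return zero(T)
    elseif x <= c
        return x / c
    else
        return one(T)
    end
end

# Optional convenience wrapper (an assumption, not from the answer):
# promote mixed real arguments to one type, then convert to floating point.
function f_promote(x::Real, c::Real)
    xp, cp = promote(x, c)
    return f_generic(float(xp), float(cp))
end

# Inferred return type of the single matching method.
rt(f, argtypes) = first(Base.return_types(f, argtypes))

println(rt(f_literal, (Float32, Float32)))  # a Union of Float32 and Float64
println(rt(f_generic, (Float32, Float32)))  # Float32
println(typeof(f_promote(1.0f0, 2)))        # Float32
```

On the last idea in the question: 0.0::T is a type assertion, not a conversion, so it would throw a TypeError for T = Float32. zero(T), one(T), or an explicit conversion such as T(0.5) produce a value of type T instead. For interactive checks, @code_warntype f_generic(0.5f0, 2.0f0) or Test.@inferred report the same information as rt.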