Tags: python, optimization, julia, jax

How to improve Julia's performance using just-in-time compilation (JIT)


I have been playing with JAX (an automatic differentiation library for Python) and Zygote (an automatic differentiation library for Julia) to implement the Gauss-Newton minimisation method. I came upon the @jit decorator in JAX, which runs my Python code in around 0.6 seconds, compared to ~60 seconds for the version that does not use @jit. Julia ran the code in around 40 seconds. Is there an equivalent of @jit in Julia or Zygote that results in better performance?

Here is the code I used:

Python

from jax import grad, jit, jacfwd
import jax.numpy as jnp
import numpy as np
import time

def gaussian(x, params):
    amp = params[0]
    mu  = params[1]
    sigma = params[2]
    amplitude = amp/(jnp.abs(sigma)*jnp.sqrt(2*np.pi))
    arg = ((x-mu)/sigma)
    return amplitude*jnp.exp(-0.5*(arg**2))

# Forward-mode Jacobian of the model with respect to the parameters
def myjacobian(x, params):
    return jacfwd(gaussian, argnums = 1)(x, params)

# (J^T J)^-1 J^T: the left pseudo-inverse applied in the Gauss-Newton update
def op(jac):
    return jnp.matmul(
        jnp.linalg.inv(jnp.matmul(jnp.transpose(jac),jac)),
        jnp.transpose(jac))
                         
def res(x, data, params):
    return data - gaussian(x, params)

# One Gauss-Newton update of the parameters, compiled with jit
@jit
def step(x, data, params):
    residuals = res(x, data, params)
    jacobian_operation = op(myjacobian(x, params))
    temp = jnp.matmul(jacobian_operation, residuals)
    return params + temp

N = 2000
x = np.linspace(start = -100, stop = 100, num= N)
data = gaussian(x, [5.65, 25.5, 37.23])

ini = jnp.array([0.9, 5., 5.0])
t1 = time.time()
for i in range(5000):
    ini = step(x, data, ini)
t2 = time.time()
print('t2-t1: ', t2-t1)
ini

Julia

using Zygote

function gaussian(x::Union{Vector{Float64}, Float64}, params::Vector{Float64})
    amp = params[1]
    mu  = params[2]
    sigma = params[3]
    
    amplitude = amp/(abs(sigma)*sqrt(2*pi))
    arg = ((x.-mu)./sigma)
    return amplitude.*exp.(-0.5.*(arg.^2))
    
end

# Build the Jacobian row by row using Zygote's gradient of the model at each x
function myjacobian(x::Vector{Float64}, params::Vector{Float64})
    output = zeros(length(x), length(params))
    for (index, ele) in enumerate(x)
        output[index,:] = collect(gradient((params)->gaussian(ele, params), params))[1]
    end
    return output
end

# (J' * J)^-1 * J': the left pseudo-inverse used in the Gauss-Newton update
function op(jac::Matrix{Float64})
    return inv(jac'*jac)*jac'
end

function res(x::Vector{Float64}, data::Vector{Float64}, params::Vector{Float64})
    return data - gaussian(x, params)
end

# One Gauss-Newton update of the parameters
function step(x::Vector{Float64}, data::Vector{Float64}, params::Vector{Float64})
    residuals = res(x, data, params)
    jacobian_operation = op(myjacobian(x, params))
    
    temp = jacobian_operation*residuals
    return params + temp
end

N = 2000
x = collect(range(start = -100, stop = 100, length= N))
params = vec([5.65, 25.5, 37.23])
data = gaussian(x, params)

ini = vec([0.9, 5., 5.0])
@time for i in range(start = 1, step = 1, length = 5000)
    ini = step(x, data, ini)
end
ini

Solution

  • Your Julia code is doing a number of things that aren't idiomatic and are worsening your performance. This won't be a full overview, but it should give you a good idea of where to start.

    The first thing is that passing params as a Vector is a bad idea. This means it has to be heap allocated, and the compiler doesn't know how long it is. Instead, use a Tuple, which allows for a lot more optimization. Secondly, don't make gaussian act on a Vector of xs. Instead, write the scalar version and broadcast it. Specifically, with these changes, you will have the function below (a sketch of the matching call-site changes follows it).

    function gaussian(x::Number, params::NTuple{3, Float64})
        amp, mu, sigma = params
        
        # The next 2 lines should probably be done outside this function, but I'll leave them here for now.
        amplitude = amp/(abs(sigma)*sqrt(2*pi))
        arg = ((x-mu)/sigma)
        return amplitude*exp(-0.5*(arg^2))
    end
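
    With the scalar, Tuple-based gaussian above, the surrounding code needs a few matching changes. Below is a minimal sketch of one way to adapt the call sites, assuming the same names (myjacobian, op, res, step) as in the question; the Ref trick for broadcasting and the Tuple-typed signatures are one possible arrangement, not the only one.

    # One possible adaptation of the call sites (uses `gradient` from Zygote, as before)
    params = (5.65, 25.5, 37.23)          # NTuple{3, Float64} instead of a Vector

    # Broadcast the scalar model over x; Ref keeps the Tuple from being broadcast over
    data = gaussian.(x, Ref(params))

    # Same row-by-row Jacobian as before, now differentiating with respect to a Tuple
    function myjacobian(x::Vector{Float64}, params::NTuple{3, Float64})
        output = zeros(length(x), length(params))
        for (i, xi) in enumerate(x)
            output[i, :] .= gradient(p -> gaussian(xi, p), params)[1]
        end
        return output
    end

    # res and step pick up the same two changes: Tuple-typed params and a broadcast call
    res(x, data, params::NTuple{3, Float64}) = data .- gaussian.(x, Ref(params))

    function step(x, data, params::NTuple{3, Float64})
        residuals = res(x, data, params)
        J = myjacobian(x, params)
        return params .+ Tuple(op(J) * residuals)   # keep params a Tuple across iterations
    end

    ini = (0.9, 5.0, 5.0)                 # the starting guess also becomes a Tuple

    The op function is unchanged, and the timing loop only needs to start from the Tuple ini.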