I have been playing with JAX (an automatic differentiation library in Python) and Zygote (an automatic differentiation library in Julia) to implement the Gauss-Newton minimisation method. I came upon the @jit decorator in JAX, which runs my Python code in around 0.6 seconds, compared to ~60 seconds for the version that does not use @jit. Julia ran the code in around 40 seconds. Is there an equivalent of @jit in Julia or Zygote that results in better performance?

Here is the code I used:

Python

from jax import grad, jit, jacfwd
import jax.numpy as jnp
import numpy as np
import time

def gaussian(x, params):
    amp = params[0]
    mu  = params[1]
    sigma = params[2]
    amplitude = amp/(jnp.abs(sigma)*jnp.sqrt(2*np.pi))
    arg = ((x-mu)/sigma)
    return amplitude*jnp.exp(-0.5*(arg**2))

def myjacobian(x, params):
    return jacfwd(gaussian, argnums = 1)(x, params)

def op(jac):
    return jnp.matmul(
        jnp.linalg.inv(jnp.matmul(jnp.transpose(jac),jac)),
        jnp.transpose(jac))
                         
def res(x, data, params):
    return data - gaussian(x, params)

@jit
def step(x, data, params):
    residuals = res(x, data, params)
    jacobian_operation = op(myjacobian(x, params))
    temp = jnp.matmul(jacobian_operation, residuals)
    return params + temp

N = 2000
x = np.linspace(start = -100, stop = 100, num= N)
data = gaussian(x, [5.65, 25.5, 37.23])

ini = jnp.array([0.9, 5., 5.0])
t1 = time.time()
for i in range(5000):
    ini = step(x, data, ini)
t2 = time.time()
print('t2-t1: ', t2-t1)
ini

Julia

using Zygote

function gaussian(x::Union{Vector{Float64}, Float64}, params::Vector{Float64})
    amp = params[1]
    mu  = params[2]
    sigma = params[3]
    
    amplitude = amp/(abs(sigma)*sqrt(2*pi))
    arg = ((x.-mu)./sigma)
    return amplitude.*exp.(-0.5.*(arg.^2))
    
end

function myjacobian(x::Vector{Float64}, params::Vector{Float64})
    output = zeros(length(x), length(params))
    for (index, ele) in enumerate(x)
        output[index,:] = collect(gradient((params)->gaussian(ele, params), params))[1]
    end
    return output
end

function op(jac::Matrix{Float64})
    return inv(jac'*jac)*jac'
end

function res(x::Vector{Float64}, data::Vector{Float64}, params::Vector{Float64})
    return data - gaussian(x, params)
end

function step(x::Vector{Float64}, data::Vector{Float64}, params::Vector{Float64})
    residuals = res(x, data, params)
    jacobian_operation = op(myjacobian(x, params))
    
    temp = jacobian_operation*residuals
    return params + temp
end

N = 2000
x = collect(range(start = -100, stop = 100, length= N))
params = vec([5.65, 25.5, 37.23])
data = gaussian(x, params)

ini = vec([0.9, 5., 5.0])
@time for i in range(start = 1, step = 1, length = 5000)
    ini = step(x, data, ini)
end
ini
MOON
  • I believe Julia is JIT-compiled by default. Although based on your results it looks like JAX (via Numba) is doing a better job. – DavidW Dec 04 '22 at 17:25
  • Small correction: JAX uses the XLA compiler, not Numba. – jakevdp Dec 04 '22 at 18:53
  • JAX may default to 32-bit floats here. I also recall some tricks needed to time it accurately -- are you sure 0.6 seconds is the final synchronised result, not the time to launch a kernel? – mcabbott Dec 04 '22 at 19:01
  • @mcabbott The Jax version is really faster. In less than one second after I evaluate the cell containing the fitting loop, it returns the fitted parameters. The 0.6 seconds may not be an accurate measurement, I used `time.time()`, but it is noticeably faster. – MOON Dec 04 '22 at 19:13
  • Yes I don't doubt that it helps, and 1s vs 60s of wall-clock time is hard to miss. Maybe `.block_until_ready()` is the thing I needed for `%timeit` not to be fooled? But haven't tried today. – mcabbott Dec 04 '22 at 19:30
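The timing concerns raised in these comments can be sketched as follows, assuming a reasonably recent JAX install. JAX dispatches work asynchronously and defaults to 32-bit floats, so the sketch enables 64-bit mode, runs one warm-up call so compilation is excluded, and calls `block_until_ready()` before reading the clock. The `step` function here is a hypothetical stand-in, not the Gauss-Newton step from the question.

```python
import time

import jax
import jax.numpy as jnp

# JAX defaults to float32; enable 64-bit mode to match Julia's Float64.
# Set this before creating any arrays.
jax.config.update("jax_enable_x64", True)


@jax.jit
def step(params):
    # Hypothetical stand-in for one solver iteration.
    return params - 0.1 * jnp.sin(params)


params = jnp.array([0.9, 5.0, 5.0])

# Warm-up call: triggers tracing and XLA compilation, kept out of the timing.
step(params).block_until_ready()

t1 = time.perf_counter()
for _ in range(1000):
    params = step(params)
# Wait for the asynchronously dispatched work to finish before stopping the clock.
params.block_until_ready()
t2 = time.perf_counter()

elapsed = t2 - t1
print("elapsed:", elapsed, "dtype:", params.dtype)
```

Without the final `block_until_ready()`, the measured interval may only cover the time needed to enqueue the work rather than to complete it.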

2 Answers

Your Julia code is doing a number of things that aren't idiomatic and are worsening your performance. This won't be a full overview, but it should give you a good place to start.

First, passing params as a Vector is a bad idea: it has to be heap allocated, and the compiler doesn't know how long it is. Instead, use a Tuple, which allows for a lot more optimization. Second, don't make gaussian act on a Vector of xs. Instead, write the scalar version and broadcast it. With these changes, you will have

function gaussian(x::Number, params::NTuple{3, Float64})
    amp, mu, sigma = params
    
    # The next 2 lines should probably be done outside this function, but I'll leave them here for now.
    amplitude = amp/(abs(sigma)*sqrt(2*pi))
    arg = ((x-mu)/sigma)
    return amplitude*exp(-0.5*(arg^2))
end
Oscar Smith
  • Thanks! Using `Tuple` brought the running time to 2 seconds from 20 seconds! The broadcasting did not improve the running time. With regards to the tuple, where can I find similar tips? Is this particular case covered in here: https://docs.julialang.org/en/v1/manual/performance-tips/ – MOON Dec 04 '22 at 19:53
  • The broadcasting tip was just code clarity rather than performance (although it does make it harder to forget dots). I don't think the `Tuple` part is in the performance tips, but it perhaps should be. It's less of a big deal than other things like type stability, but it definitely can help. – Oscar Smith Dec 04 '22 at 20:04

One straightforward way to speed this up is to use ForwardDiff rather than Zygote, since you are taking the gradient with respect to a vector of length 3, many times. Here this gets me from 16 to 3.5 seconds, with the last factor of 2 coming from passing Chunk(3) to improve type-stability. Perhaps this can be improved further.

using ForwardDiff

function myjacobian(x::Vector, params)
    # return rand(eltype(x), length(x), length(params))  # with no gradient, takes 0.5s
    output = zeros(eltype(x), length(x), length(params))
    config = ForwardDiff.GradientConfig(nothing, params, ForwardDiff.Chunk(3))
    for (i, xi) in enumerate(x)
        # grad = gradient(p->gaussian(xi, p), params)[1]       # original, takes 16s
        # grad = ForwardDiff.gradient(p-> gaussian(xi, p), params)  # ForwardDiff, takes 7s
        grad = ForwardDiff.gradient(p-> gaussian(xi, p), params, config)  # takes 3.5s
        copyto!(view(output,i,:), grad)  # this allows params::Tuple, OK for Zygote, no help
    end
    return output
end
# This needs gaussian.(x, Ref(params)) elsewhere to use on many x, same params
function gaussian(x::Real, params)
    # amp, mu, sigma = params  # with params::Vector this is slower, 19 sec
    amp = params[1]
    mu  = params[2]
    sigma = params[3]  # like this, 16 sec
    T = typeof(x)  # avoids having (2*pi)::Float64 promote everything
    amplitude = amp/(abs(sigma)*sqrt(2*T(pi)))
    arg = (x-mu)/sigma
    return amplitude * exp(-(arg^2)/2)
end

However, this is still computing many small gradient arrays in a loop. It could easily compute one big gradient array instead.

While in general Julia is happy to compile loops to something fast, loops that make individual arrays tend to be a bad idea. And this is especially true for Zygote, which is fastest on matlab-ish whole-array code.

Here's how this looks; it gets me under 1s for the whole program:

function gaussian(x::Real, amp::Real, mu::Real, sigma::Real)
    T = typeof(x)
    amplitude = amp/(abs(sigma)*sqrt(2*T(pi)))
    arg = (x-mu)/sigma
    return amplitude * exp(-(arg^2)/2)
end
function myjacobian2(x::Vector, params)  # with this, 0.9s
    amp = fill(params[1], length(x))
    mu  = fill(params[2], length(x))
    sigma = fill(params[3], length(x))  # use same sigma & different x value at each row:
    grads = gradient((amp, mu, sigma) -> sum(gaussian.(x, amp, mu, sigma)), amp, mu, sigma)
    hcat(grads...)
end
# Check that it agrees:
myjacobian2(x, params) ≈ myjacobian(x, params)
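For comparison, here is the same whole-array idea on the JAX side, as a sketch that assumes JAX is installed. Applying `jacfwd` to the vector-valued `gaussian` yields the full N×3 Jacobian in one call, and it agrees with a per-row gradient built from `vmap(grad(...))`. The function mirrors the question's `gaussian`; the small grid of 50 points is chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Use 64-bit floats so the comparison matches Julia's Float64.
jax.config.update("jax_enable_x64", True)


def gaussian(x, params):
    amp, mu, sigma = params[0], params[1], params[2]
    amplitude = amp / (jnp.abs(sigma) * jnp.sqrt(2 * jnp.pi))
    return amplitude * jnp.exp(-0.5 * ((x - mu) / sigma) ** 2)


x = jnp.linspace(-100.0, 100.0, 50)
params = jnp.array([5.65, 25.5, 37.23])

# Whole-array: one forward-mode Jacobian over all x at once, shape (50, 3).
jac_whole = jax.jacfwd(gaussian, argnums=1)(x, params)

# Per-row: gradient of the scalar function, mapped over x with params shared.
jac_rows = jax.vmap(jax.grad(gaussian, argnums=1), in_axes=(0, None))(x, params)

print(jac_whole.shape, bool(jnp.allclose(jac_whole, jac_rows)))
```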

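Although it has little effect on the speed, I think you probably also want op(jac::Matrix) = Hermitian(jac'*jac) \ jac' rather than inv (Hermitian comes from the LinearAlgebra standard library, so this needs using LinearAlgebra). Solving the linear system avoids forming the explicit inverse.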
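The same point in NumPy terms, as a sketch on synthetic data (the `jac` matrix and `residuals` vector below are random and only for illustration): the Gauss-Newton update can be computed by solving the normal equations, or by a least-squares solve on the Jacobian itself, without forming an inverse.

```python
import numpy as np

rng = np.random.default_rng(0)
jac = rng.normal(size=(200, 3))    # synthetic Jacobian, 200 points x 3 parameters
residuals = rng.normal(size=200)   # synthetic residual vector

# Explicit inverse, as in the question's op().
delta_inv = np.linalg.inv(jac.T @ jac) @ jac.T @ residuals

# Solve the normal equations (J^T J) delta = J^T r directly.
delta_solve = np.linalg.solve(jac.T @ jac, jac.T @ residuals)

# Least-squares solve on J itself, which avoids squaring the condition number.
delta_lstsq, *_ = np.linalg.lstsq(jac, residuals, rcond=None)

print(np.allclose(delta_inv, delta_solve), np.allclose(delta_solve, delta_lstsq))
```

On a well-conditioned problem like this one the three updates agree; the difference shows up when `jac.T @ jac` is close to singular.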

mcabbott
  • Oh I'm an idiot, I was timing Oscar's Tuple suggestion incorrectly. Now I get 0.52s using that, vs. 0.67s for this `myjacobian2` idea. (Down from 16s for the initial version.) I believe that the JAX version's `jacfwd(gaussian, argnums = 1)(x, params)` is a similar whole-array operation. – mcabbott Dec 04 '22 at 20:17
  • Why didn't you use `reduce(hcat, grads)`? And you're calculating `length(x)` three times (too much!!). I just took a look and saw these two. I'm sure that Julia can do it much faster, indeed. (without using any 3rd party package to enhance the performance!) – Shayan Dec 04 '22 at 21:02
  • `length(x)` is approximately free. Seems like it ought to be possible to avoid the `fill` though. `hcat(grads...)` is fine, it's a tuple of 3 items. – mcabbott Dec 04 '22 at 21:07
  • @mcabbott Do you get a speedup if you combine your approach with that of the Tuple suggestion? – MOON Dec 04 '22 at 21:23
  • Not sure. The four arguments `(x, amp, mu, sigma)` are essentially a tuple already. Since the gradient is only half the time now, the other operations start to matter, and perhaps a tuple storage works well. Could also consider StaticArrays. – mcabbott Dec 05 '22 at 00:46