
I am trying to numerically solve an ODE that admits discrete jumps. I am using the Euler method and was hoping that Numba's jit might help speed up the process (right now the script takes 300 s to run, and I need to run it 200 times).

Here is my simplified first attempt:

import numpy as np
from numba import jit

dt = 1e-5
T = 1
x0 = 1
noiter = int(T / dt)
res = np.zeros(noiter)

def fdot(x, t):
    return -x + t / (x + 1) ** 2

def solve_my_ODE(res, fdot, x0, T, dt):
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res

%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
    ->The slowest run took 8.38 times longer than the fastest. This could mean that an intermediate result is being cached 
    ->1000000 loops, best of 3: 465 ns per loop
    ->10 loops, best of 3: 122 ms per loop

@jit(nopython=True)
def fdot(x, t):
    return -x + t / (x + 1) ** 2
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
    ->The slowest run took 106695.67 times longer than the fastest. This could mean that an intermediate result is being cached 
    ->1000000 loops, best of 3: 240 ns per loop
    ->10 loops, best of 3: 99.3 ms per loop

@jit(nopython=True)
def solve_my_ODE(res, fdot, x0, T, dt):
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res
%timeit fdot(x0, T)
%timeit solve_my_ODE(res, fdot, x0, T, dt)
    ->The slowest run took 10.21 times longer than the fastest. This could mean that an intermediate result is being cached 
    ->1000000 loops, best of 3: 274 ns per loop
    ->TypingError                               Traceback (most recent call last)
<ipython-input-10-27199e82c72c> in <module>()
  1 get_ipython().magic('timeit fdot(x0, T)')
----> 2 get_ipython().magic('timeit solve_my_ODE(res, fdot, x0, T, dt)')

(...)


TypingError: Failed at nopython (nopython frontend)
Undeclared pyobject(float64, float64)
File "<ipython-input-9-112bd04325a4>", line 6

I don't understand why I got this error. My suspicion is that Numba does not recognize the argument fdot (a Python function which, by the way, is already compiled with Numba).

Since I am so new to Numba, I have several questions:

  • What can I do to make Numba understand that the argument fdot is a function?
  • Using jit on the function fdot "only" reduces its run time by about 50%. Should I expect more, or is this normal?
  • Does this script look like a reasonable way to simulate an ODE with discrete jumps? Mathematically, this is equivalent to solving an ODE with delta functions.

Numba version is 0.17

gota

2 Answers


You're right in thinking that Numba doesn't recognise fdot as a Numba-compiled function. I don't think you can make Numba accept it as a function argument, but you can use this approach (variable capture, so that fdot is known when the solver is built) to build an ODE solver:

def make_solver(f):
    @jit(nopython=True)
    def solve_my_ODE(res, x0, T, dt):
        res[0] = x0
        noiter = int(T / dt)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE

# call this for each function you want to make an ODE solver for
fdot_solver = make_solver(fdot)
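In more recent Numba releases (much newer than the 0.17 used in the question), a jitted function can be passed directly as an argument to another jitted function; Numba compiles a separate specialization for each function passed in. A minimal sketch assuming such a release: the `njit` decorator is used instead of `jit(nopython=True)`, and the `try`/`except` fallback is an assumption added only so the snippet also runs (slowly) when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch runs without Numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


@njit
def fdot(x, t):
    # right-hand side from the question
    return -x + t / (x + 1) ** 2


@njit
def solve_my_ODE(res, f, x0, T, dt):
    # f is itself a jitted function; recent Numba versions accept it as an argument
    res[0] = x0
    noiter = int(T / dt)
    for i in range(noiter - 1):
        res[i + 1] = res[i] + dt * f(res[i], i * dt)
        if res[i + 1] >= 2:  # discrete jump back down by 2
            res[i + 1] -= 2
    return res


dt = 1e-4
T = 1.0
res = np.zeros(int(T / dt))
out = solve_my_ODE(res, fdot, 1.0, T, dt)
```

The closure factory above is still a good fit when you want one compiled solver per right-hand side without passing the function around at every call.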

Here's an alternate version which doesn't require you to pass res to it. Only the loop is accelerated, but since the loop is the slow part, that's the only part that matters.

def make_solver_2(f):
    @jit
    def solve_my_ODE(x0, T, dt):
        # this part ISN'T compiled in nopython mode
        noiter = int(T / dt)
        res = np.zeros(noiter)
        res[0] = x0
        # but the loop is lifted and compiled in nopython mode (so it's fast)
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE

I prefer this version because it allocates the return value for you, so it's a little easier to use. That's a slight diversion from your real question though.
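Version 2 relies on Numba silently falling back to object mode and lifting the loop; recent Numba releases have deprecated that fallback. Since `np.zeros` is supported in nopython mode, the allocation can instead live inside a fully compiled function while keeping the closure approach of version 1. A sketch under that assumption; `make_solver_3` is a hypothetical name, and the `try`/`except` fallback is added only so the snippet runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch runs without Numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


def make_solver_3(f):
    # hypothetical factory; f should already be jitted when Numba is available
    @njit
    def solve_my_ODE(x0, T, dt):
        noiter = int(T / dt)
        res = np.zeros(noiter)  # allocation now happens inside nopython code
        res[0] = x0
        for i in range(noiter - 1):
            res[i + 1] = res[i] + dt * f(res[i], i * dt)
            if res[i + 1] >= 2:
                res[i + 1] -= 2
        return res
    return solve_my_ODE


@njit
def fdot(x, t):
    return -x + t / (x + 1) ** 2


solver = make_solver_3(fdot)
res = solver(1.0, 1.0, 1e-4)
```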

In terms of timings I get (in seconds, for 20 iterations):

  • 6.90394687653 (for only fdot in numba)
  • 0.0584900379181 (for version 1)
  • 0.0640540122986 (for version 2 - i.e. it's slightly slower but a little easier to use)

Thus, it's roughly 100x faster - accelerating the loop makes a big difference!
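The `timeit` warning in the question ("The slowest run took 106695.67 times longer than the fastest") most likely reflects the first call paying for JIT compilation. A sketch of timing that excludes this one-off cost; `euler_with_jumps` and `best_time` are hypothetical names, and the Numba fallback is an assumption so the code runs without Numba.

```python
import time
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch runs without Numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


@njit
def euler_with_jumps(x0, T, dt):
    # same Euler loop as above, with fdot written inline
    noiter = int(T / dt)
    res = np.zeros(noiter)
    res[0] = x0
    for i in range(noiter - 1):
        x = res[i]
        res[i + 1] = x + dt * (-x + i * dt / (x + 1) ** 2)
        if res[i + 1] >= 2:
            res[i + 1] -= 2
    return res


def best_time(func, *args, repeat=5):
    """Return (best wall-clock time in seconds, last result), excluding the first call."""
    result = func(*args)  # warm-up: with Numba, this first call includes compilation
    timings = []
    for _ in range(repeat):
        start = time.perf_counter()
        result = func(*args)
        timings.append(time.perf_counter() - start)
    return min(timings), result


seconds, res = best_time(euler_with_jumps, 1.0, 1.0, 1e-4, repeat=3)
print(f"best of 3: {seconds * 1e3:.3f} ms for {len(res)} points")
```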

Your third question: "Does this script look like a reasonable way to simulate an ODE with discrete jumps? Mathematically, this is equivalent to solving an ODE with delta functions." I really don't know. Sorry!

DavidW

To the last point:

  • In its current form, it is not even a valid implementation for a well-behaved ODE. It stops one step too early (the last "regular" step should go to noiter*dt), and it does not account for the remaining time T-noiter*dt.

    Note that range(N) generates the numbers 0,1,…,N-1. Equally, res=zeros(N) generates an array with N entries, from res[0] to res[N-1].

  • The switching should not depend on the discretization, i.e., on the step length. To that end, a more exact time for the crossing of the switching condition should be determined via interpolation (linear or inverse quadratic), and then the modified or new system restarted with the new initial conditions. To preserve the desired grid, use a short first step after the restart.


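import numpy as np

def solve_my_ODE(fdot, x0, T, dt):
    noiter = int(T / dt)
    dt = T / noiter          # adapt the timestep so that noiter steps end exactly at T
    res = np.zeros(noiter + 1)
    res[0] = x0
    for i in range(noiter):
        res[i + 1] = res[i] + dt * fdot(res[i], i * dt)
        if res[i + 1] >= 2:
            # fraction h of the step at which the linear interpolant reaches 2;
            # the denominator is positive as long as res[i] < 2
            h = (2 - res[i]) / (res[i + 1] - res[i])
            # jump 2 -> 0 at t = (i + h) * dt, then a short Euler step of (1 - h) * dt back onto the grid
            res[i + 1] = 0 + (1 - h) * dt * fdot(0, (i + h) * dt)
    return res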
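With a coarse step, the two problems named in the first bullet are easy to see. A small check, assuming T = 1 and dt = 0.3 purely for illustration; `step_times` is a hypothetical helper that lists the time reached after each Euler step for the question's loop and for the corrected loop.

```python
def step_times(T, dt, corrected):
    """Times reached after each Euler step (hypothetical helper for illustration)."""
    noiter = int(T / dt)           # int() truncates, so noiter * dt can fall short of T
    if corrected:
        dt = T / noiter            # adapted step: noiter steps end exactly at T
        n_steps = noiter           # loop over range(noiter)
    else:
        n_steps = noiter - 1       # loop over range(noiter - 1), as in the question
    return [(i + 1) * dt for i in range(n_steps)]


original = step_times(1.0, 0.3, corrected=False)   # ends at 0.6, well short of T = 1
corrected = step_times(1.0, 0.3, corrected=True)   # ends at T = 1 (up to rounding)
print(original, corrected)
```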

  • It appears that a final accuracy of better than 1e-4 is desired. Here, with dt=1e-5, the computation uses 100 000 steps and as many function evaluations.

    Using the classical Runge-Kutta method with h=0.05 will result in an error slightly larger than 1e-5 (h**4=6.25e-6), i.e., comparable in size to the Euler method error. However, this now requires only T/h=20 steps with a total of 80 function evaluations. Note that the switching time also needs to be accurate to order O(h**4) so as not to contaminate the global error order.

    Thus, if speed is the objective, it pays to invest in higher-order methods.
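A sketch of the higher-order route, assuming the classical RK4 step and locating the switching time by bisection on the step length (a swapped-in technique; the answer itself mentions interpolation). Bisection to a tight tolerance keeps the event time accurate enough not to spoil the O(h**4) global error. The names `rk4_step`, `solve_rk4_with_jumps`, `threshold`, `jump` and `tol` are hypothetical, the step count is rounded rather than truncated, and at most one jump per step is handled. With only about 20 steps, plain Python without Numba is enough here.

```python
import numpy as np


def fdot(x, t):
    return -x + t / (x + 1) ** 2


def rk4_step(f, x, t, h):
    # one classical Runge-Kutta step: 4 function evaluations
    k1 = f(x, t)
    k2 = f(x + 0.5 * h * k1, t + 0.5 * h)
    k3 = f(x + 0.5 * h * k2, t + 0.5 * h)
    k4 = f(x + h * k3, t + h)
    return x + h * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def solve_rk4_with_jumps(f, x0, T, h, threshold=2.0, jump=2.0, tol=1e-12):
    n = max(1, int(round(T / h)))
    h = T / n                           # adapt the step so the grid ends exactly at T
    t = np.linspace(0.0, T, n + 1)
    x = np.empty(n + 1)
    x[0] = x0
    for i in range(n):
        x_new = rk4_step(f, x[i], t[i], h)
        if x_new >= threshold:
            # bisection on the sub-step s in (0, h] where the RK4 solution hits the threshold
            lo, hi = 0.0, h
            while hi - lo > tol:
                mid = 0.5 * (lo + hi)
                if rk4_step(f, x[i], t[i], mid) >= threshold:
                    hi = mid
                else:
                    lo = mid
            s = hi
            x_jump = rk4_step(f, x[i], t[i], s) - jump      # apply the jump at t[i] + s
            x_new = rk4_step(f, x_jump, t[i] + s, h - s)    # short step back onto the grid
        x[i + 1] = x_new
    return t, x


t, x = solve_rk4_with_jumps(fdot, 1.0, 1.0, 0.05)   # 20 steps, 80 evaluations without jumps
```

For a quick sanity check of the jump handling, a constant right-hand side `f(x, t) = 1` started at 0 should give the sawtooth `x(t) = t mod 2`, which RK4 reproduces exactly apart from rounding.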

Lutz Lehmann