
I've been working on speeding up a resampling calculation for a particle filter. As Python has many ways to speed code up, I thought I'd try them all. Unfortunately, the Numba version is incredibly slow. As Numba should result in a speedup, I assume this is an error on my part.

I tried 4 different versions:

  1. Numba
  2. Python
  3. Numpy
  4. Cython

The code for each is below:

import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)
    for j, key in enumerate(rands):
        # argmax of a boolean array returns the index of the first True
        i = np.argmax(lookup > key)
        results[j] = xs[i]
    return results

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs, 
             np.ndarray[DTYPE_t, ndim=1] xs, 
             np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j 
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)

This results in the following output:

Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop

Any idea why the numba code is so slow? I assumed it would be at least comparable to Numpy.

Note: if anyone has ideas on how to speed up either the Numpy or Cython code samples, that would be nice too. :) My main question is about Numba, though.

jiminy_crist
  • I think a better place for this would be http://codereview.stackexchange.com/ – kylieCatt Jan 30 '14 at 21:53
  • try it with a much larger list? – Joran Beasley Jan 30 '14 at 21:58
  • @IanAuld: Perhaps, but as others have gotten substantial speed ups from numba, I figure it's that I'm using it wrong, rather than a mere profiling issue. This seems to me to fit stackoverflow's intended use. – jiminy_crist Jan 30 '14 at 22:02
  • @JoranBeasley: I tried it with 1000, and 10000 points. Numba took 773 ms to run with 1000, compared to 234 ms with pure python. The 10000 point trial is still running... – jiminy_crist Jan 30 '14 at 22:10
  • Do you have a reasonable GPU that's compatible with numba? (I'm not sure what the requirements are.) As an aside, I was just watching this http://www.youtube.com/watch?v=iYAG6I433gQ this morning ... it may shed some light (it may not) – Joran Beasley Jan 30 '14 at 22:37
  • I'm just using regular numba, so it runs on the cpu alone. Still, I should get some speedup. I read through the user guide here: http://numba.pydata.org/numba-doc/0.11/userguide.html, and everything looks right to me. Like cython, it should be faster to use loops than to rely on numpy's vectorized routines. – jiminy_crist Jan 30 '14 at 22:52
  • As a note, `argmax` can take an axis argument, so you can broadcast `rands` and `lookup` against each other to make an `n x n` matrix for an N^2 scaling algorithm. Alternatively you can use searchsorted, which will have (should have?) N log(N) scaling. – Daniel Jan 31 '14 at 01:06
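The `searchsorted` suggestion in the last comment can be sketched as follows. This is a minimal illustration, not code from the thread: the function name `searchsorted_resample` and the clipping step are assumptions added here. `side="right"` returns the first index whose cumulative weight is strictly greater than the random draw, which matches the `rands[j] < lookup[i]` test in the loop versions.

```python
import numpy as np

def searchsorted_resample(qs, xs, rands):
    """Resample xs with weights qs via binary search (hypothetical helper)."""
    lookup = np.cumsum(qs)
    # First index i with lookup[i] > rands[j], found in O(log n) per draw.
    idx = np.searchsorted(lookup, rands, side="right")
    # Assumption: if rounding leaves a draw >= lookup[-1], use the last particle.
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]

qs = np.array([0.1, 0.2, 0.3, 0.4])
xs = np.array([10.0, 20.0, 30.0, 40.0])
rands = np.array([0.05, 0.15, 0.25, 0.95])
print(searchsorted_resample(qs, xs, rands))
```

Unlike the broadcasting approach, this never builds an `n x n` intermediate array, so memory use stays linear in `n`.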

2 Answers


The problem is that numba can't infer the type of lookup. If you put a print nb.typeof(lookup) in your function, you'll see that numba is treating it as an object, which is slow. Normally I would just declare the type of lookup in a locals dict, but I was getting a strange error. Instead I created a small wrapper so that I could explicitly define the input and output types.

@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

Then my timings are:

print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)

print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)

Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop

You can go a little faster still if you use jit with an explicit signature instead of autojit:

@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

For me that lowers it from 15.3 microseconds to 12.5 microseconds, but it's still impressive how well autojit does.
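For readers on newer Numba releases, where as far as I know `autojit` is no longer available, a rough equivalent can be sketched with `numba.njit`. It compiles in nopython mode, so a type-inference failure raises an error instead of silently falling back to slow object code. The function name `resample_nopython`, the default-to-last-particle line, and the import fallback are assumptions added for illustration, and no timings are claimed for it.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for illustration: without Numba, run as plain Python.
    def njit(func):
        return func

@njit
def resample_nopython(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        # Default to the last particle in case rounding leaves
        # rands[j] >= lookup[-1] and the inner loop never matches.
        results[j] = xs[n - 1]
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

qs = np.array([0.1, 0.2, 0.3, 0.4])
xs = np.array([10.0, 20.0, 30.0, 40.0])
rands = np.array([0.05, 0.15, 0.25, 0.95])
print(resample_nopython(qs, xs, rands))
```

The first call includes compilation time, so it should be excluded when timing.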

JoshAdel
  • Yep, that fixed it! I tried playing around with unrolling the loop on the numba_cumsum function, and jit-ing that as well, but it either ran slower, or failed to compile. Looks like this is about as fast as it can go. What's odd to me is that the numba version now runs consistently ~twice as fast as the cython code. As they are both compiled, I find this odd. Thoughts? – jiminy_crist Feb 01 '14 at 00:31
  • @jammycrisp - I also tried hand-coding the cumsum and I found it to be marginally slower than calling out to numpy. As far as differences between cython and numba, it could perhaps be related to whatever c compiler you're using vs llvm. What compiler are you using? Are you specifying any optimization flags in your `setup.py`? – JoshAdel Feb 01 '14 at 00:47
  • I'm using GCC 4.6.3. I didn't know you could add compiler flags to setup.py, but after figuring it out I compiled with -O3, and it didn't seem to change anything. – jiminy_crist Feb 01 '14 at 01:51

Faster numpy version (10x speedup compared to numpy_resample)

def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    mm = lookup[None,:]>rands[:,None]
    I = np.argmax(mm,1)
    return xs[I]
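To see what the broadcasting in `numpy_faster` does, here is a small worked example with made-up weights and draws (the four-element arrays are illustrative, not data from the question). Each row of the boolean mask corresponds to one random draw, and `argmax` along axis 1 picks the first `True` in that row.

```python
import numpy as np

qs = np.array([0.1, 0.2, 0.3, 0.4])
xs = np.array([10.0, 20.0, 30.0, 40.0])
rands = np.array([0.05, 0.15, 0.25, 0.95])

lookup = np.cumsum(qs)
# mask[j, i] is True where lookup[i] > rands[j]; shape is (len(rands), len(qs)).
mask = lookup[None, :] > rands[:, None]
# argmax on a boolean row returns the index of its first True.
first_true = np.argmax(mask, axis=1)
resampled = xs[first_true]
print(first_true, resampled)
```

Note that the mask holds `n * n` booleans, so memory grows quadratically with the number of particles, and a row with no `True` at all makes `argmax` return 0.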
hpaulj
  • Thanks. I figured there was a way to do this, but didn't look too much into it before just skipping to cython. For n=100 I only get a 2x speedup from the old numpy function using this, but it's good to know. Still curious why my numba code doesn't work though. – jiminy_crist Jan 31 '14 at 06:58