12

Has anyone succeeded in speeding up scikit-learn models using numba and JIT compilation? The specific models I am looking at are regression models such as logistic regression.

I am able to use numba to optimize the functions I write around sklearn models, but the model functions themselves are not affected by this and are not optimized, so there is no notable increase in speed. Is there a way to optimize the sklearn functions?

Any info about this would be much appreciated.

mzm

2 Answers

9

Scikit-learn makes heavy use of numpy, most of which is written in C and already compiled (hence not eligible for JIT optimization).

Further, the LogisticRegression model is essentially LinearSVC with the appropriate loss function. I could be slightly wrong about that, but in any case, it uses LIBLINEAR to do the solving, which is again a compiled C library.

The scikit-learn developers also make heavy use of a Python-to-C compiler (Cython, a descendant of Pyrex), which again produces optimized, ahead-of-time compiled machine code that is ineligible for JIT compilation.

Andreus
  • Thank you for your answer. What can one use as an alternative then? TensorFlow? – Xiaoxiong Lin Feb 27 '20 at 16:35
  • In general, the scikit-learn models are very performant on multi-core systems. If you require GPUs to meet your performance requirements, then yes, you must switch to something like TensorFlow and use GPU-optimized algorithms. – Andreus Mar 19 '21 at 16:01
  • The answer implies that numba does not optimize numpy-heavy code, which is not true. – Keto Feb 24 '22 at 05:39
  • I agree with @Keto. The docs say "numba [...] works best on code that uses NumPy arrays and functions, and loops." – Ali Pardhan May 31 '22 at 01:48
1

The @numba.vectorize decorator (element-wise functions applied over arrays) and the @numba.guvectorize decorator (functions applied over sub-arrays, such as the rows of a matrix) may help, since they work to combine loop operations. They generate so-called "ufuncs": instead of having to write the C code yourself, numba generates the compiled code from the Python input.

See: http://numba.pydata.org/numba-doc/dev/user/vectorize.html
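As a rough illustration of the idea, here is a minimal sketch that turns a scalar logistic (sigmoid) function into a ufunc with @numba.vectorize. The function name `sigmoid` and the sample input are my own choices for the example, not anything taken from scikit-learn, and the `except ImportError` branch is only there so the snippet still runs (uncompiled) when numba is not installed.

```python
import math

import numpy as np

try:
    from numba import vectorize
except ImportError:
    # Assumption for this sketch only: without numba, fall back to
    # np.vectorize so the example runs, although nothing gets compiled.
    def vectorize(signatures):
        return np.vectorize


@vectorize(["float64(float64)"])
def sigmoid(x):
    # Written for a single scalar; the decorator broadcasts it over arrays.
    return 1.0 / (1.0 + math.exp(-x))


# Hypothetical input: 11 evenly spaced decision-function values.
z = np.linspace(-5.0, 5.0, 11)
probs = sigmoid(z)
print(probs)
```

This only speeds up code you write yourself, for example post-processing of a model's decision values. It does not change what happens inside the scikit-learn estimator.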

joshring