
Is there a way to use np.newaxis with Numba in nopython mode, in order to apply broadcasting without falling back on Python?

For example:

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

Thanks

JoshAdel
EntrustName

3 Answers


In my case (numba 0.35, numpy 1.14.0), expand_dims works fine:

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - np.expand_dims(b, -1) * np.expand_dims(c, 0)
    return d

Of course, we can omit the second expand_dims and let broadcasting handle c, since a 1-D array of shape (10,) already broadcasts like (1, 10).
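A minimal sketch of that single-expand_dims version, assuming a Numba release that supports np.expand_dims in nopython mode (0.35, per this answer). The arrays are passed in as arguments so the result can be checked against the np.newaxis expression; the function name toto_one_expand and the try/except fallback decorator are hypothetical, added only so the snippet also runs as plain NumPy when Numba isn't installed:

```python
import numpy as np

try:
    from numba import njit  # njit is shorthand for jit(nopython=True)
except ImportError:
    # Hypothetical fallback so the sketch still runs as plain NumPy without Numba
    def njit(func):
        return func

@njit
def toto_one_expand(a, b, c):
    # Only b needs an explicit new axis: shape (20,) -> (20, 1).
    # c keeps shape (10,); broadcasting aligns trailing axes, so it acts like (1, 10).
    return a - np.expand_dims(b, 1) * c

rng = np.random.default_rng(0)
a = rng.standard_normal((20, 10))
b = rng.standard_normal(20)
c = rng.standard_normal(10)

d = toto_one_expand(a, b, c)
print(d.shape)  # (20, 10)
print(np.allclose(d, a - b[:, np.newaxis] * c[np.newaxis, :]))  # True
```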

DerWeh

You can accomplish this using reshape; it looks like [:, None] indexing isn't currently supported. Note that this probably won't be much faster than doing it in plain Python, since the NumPy expression was already vectorized.

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - b.reshape((-1, 1)) * c.reshape((1,-1))
    return d
chrisb
  • I have tried it, but I get: `reshape() supports contiguous array only`. And of course, `toto()` is just an example, not my actual function. – EntrustName Aug 04 '16 at 11:51
  • You could do `b.copy().reshape((-1, 1))`. If your array isn't contiguous, I believe reshape would have had to copy it anyway, though I'm not 100% sure. – chrisb Aug 04 '16 at 12:02
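A minimal sketch of the copy-then-reshape workaround from the comments, assuming b is a non-contiguous view (here a column slice of a 2-D array, chosen for illustration). The function name outer_diff and the try/except fallback decorator are hypothetical, added so the snippet also runs as plain NumPy without Numba:

```python
import numpy as np

try:
    from numba import njit  # njit is shorthand for jit(nopython=True)
except ImportError:
    # Hypothetical fallback so the sketch still runs as plain NumPy without Numba
    def njit(func):
        return func

@njit
def outer_diff(a, b, c):
    # In nopython mode, reshape() needs a contiguous array; copy() always
    # returns a new C-contiguous array, so this works even if b is a strided view.
    b_col = b.copy().reshape((-1, 1))  # shape (n,) -> (n, 1)
    c_row = c.copy().reshape((1, -1))  # shape (m,) -> (1, m)
    return a - b_col * c_row

rng = np.random.default_rng(0)
a = rng.standard_normal((20, 10))
m = rng.standard_normal((20, 10))
b = m[:, 0]  # column view: strided, not C-contiguous
c = m[0, :]  # row view: contiguous

d = outer_diff(a, b, c)
print(b.flags.c_contiguous)  # False
print(np.allclose(d, a - b[:, np.newaxis] * c[np.newaxis, :]))  # True
```

The copy costs one extra allocation per call; for a one-dimensional b that is usually small next to the (n, m) result.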

This can be done with the newest version of Numba (0.27) and NumPy's stride_tricks. You need to be careful with this, and it's a bit ugly. Read the docstring for as_strided to make sure you understand what's going on: it isn't "safe", because it doesn't check the shape or the strides.

import numpy as np
import numba as nb

a = np.random.randn(20, 10)
b = np.random.randn(20) 
c = np.random.randn(10)

def toto(a, b, c):

    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

@nb.jit(nopython=True)
def toto2(a, b, c):
    _b = np.lib.stride_tricks.as_strided(b, shape=(b.shape[0], 1), strides=(b.strides[0], 0))
    _c = np.lib.stride_tricks.as_strided(c, shape=(1, c.shape[0]), strides=(0, c.strides[0]))
    d = a - _b * _c

    return d

x = toto(a,b,c)
y = toto2(a,b,c)
print(np.allclose(x, y))  # True
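To see what the shape and strides arguments in toto2 actually do, here is a small plain-NumPy sketch (no Numba; the tiny arange inputs are chosen for illustration, and the byte strides shown assume float64, i.e. 8 bytes per element). It checks that the as_strided views match the np.newaxis views and share memory with the originals. Because as_strided does no validation, a wrong shape or stride here would silently read outside the buffer.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

b = np.arange(5, dtype=np.float64)  # shape (5,), strides (8,)
c = np.arange(3, dtype=np.float64)  # shape (3,), strides (8,)

# Column view of b: advance b.strides[0] bytes per row; the length-1 axis
# never advances, so its stride (0) is never actually used to step.
b_col = as_strided(b, shape=(b.shape[0], 1), strides=(b.strides[0], 0))
# Row view of c: same idea with the axes swapped.
c_row = as_strided(c, shape=(1, c.shape[0]), strides=(0, c.strides[0]))

print(b_col.shape, b_col.strides)  # (5, 1) (8, 0)
print(np.shares_memory(b_col, b))  # True: no copy was made
print(np.array_equal(b_col * c_row, b[:, np.newaxis] * c[np.newaxis, :]))  # True
```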
JoshAdel