Suppose I have a 2D NumPy array (m by n). I want to get the indices of all rows that contain at least one NaN value.

This is relatively straightforward in pure NumPy:

import numpy as np

X = np.array([[1, 2], [3, np.nan], [6, 9]])

has_nan_idx = np.isnan(X).any(axis=1)

has_nan_idx
# array([False,  True, False])

How can I achieve the same using Numba's njit? I get an error because Numba does not support any with an axis argument.

Thanks.

Keptain

1 Answer

If you use guvectorize, you automatically get ufunc features such as the axis keyword.

For example:

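import numpy as np
from numba import guvectorize

# The layout "(n)->()" reduces each 1D slice of length n to a single boolean.
@guvectorize(["void(float64[:], boolean[:])"], "(n)->()")
def isnan_with_axis(x, out):
    n = x.size
    out[0] = False

    # Stop at the first NaN found in the slice.
    for i in range(n):
        if np.isnan(x[i]):
            out[0] = True
            break

# X is the array defined in the question.
isnan_with_axis(X, axis=1)
# array([False,  True, False])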
Rutger Kassies
  • As a follow-up question, it seems I cannot use it inside another function decorated with @njit. I get an error: Untyped global name 'isnan_with_axis': Cannot determine Numba type of . How can I resolve this? – Keptain Feb 27 '23 at 16:34
  • That's correct, but from within another function you can write something similar with njit and hardcode the 1D assumption. IMO guvectorize is mainly convenient as the outermost function, since that's where you'll want things like the axis keyword. – Rutger Kassies Feb 27 '23 at 16:54