I want to njit a function with numba in which pnt_group_ids_ can have one of two types, int64 (a scalar) or int64[::1] (a 1-D C-contiguous int64 array):

import numpy as np
import numba as nb

sorted_fb_i = np.array([1, 3, 4, 2, 5], np.int64)
fb_groups_ids = nb.typed.List([np.array([4, 2], np.int64), np.array([1, 3, 5], np.int64)])
moved_fb_group_ids = nb.typed.List.empty_list(nb.types.Array(dtype=nb.int64, ndim=1, layout="C"))
ind = 0

@nb.njit
def points_group_ids(sorted_fb_i, fb_groups_ids, moved_fb_group_ids, ind):
    pnt_group_ids_ = sorted_fb_i[ind]
    for i in range(len(fb_groups_ids)):
        if sorted_fb_i[ind] in fb_groups_ids[i]:
            pnt_group_ids_ = fb_groups_ids[i]
            moved_fb_group_ids.append(fb_groups_ids.pop(i))
            break
    return pnt_group_ids_, fb_groups_ids, moved_fb_group_ids

This raises the error:

Cannot unify array(int64, 1d, C) and int64 for 'pnt_group_ids_.2'

Is there any way to write a signature that can handle both types, something like:

((int64, int64[::1]), ListType(int64[::1]), ListType(int64[::1]))(int64[::1], ListType(int64[::1]), ListType(int64[::1]), int64)

If this cannot be handled by signatures, the relevant line can be replaced by:

pnt_group_ids_ = np.array([sorted_fb_i[ind]], np.int64)

This works. But how should the signature be written when there are multiple inputs and multiple outputs? Using a signature like the one above, with just one type for the first output, raises the following error:

TypeError: 'tuple' object is not callable

This function will be called in a loop, so moved_fb_group_ids (initially an empty numba list, which must be typed to avoid an error) will be filled while fb_groups_ids becomes empty. Will the emptiness of fb_groups_ids cause an error?

The main goal of this question is how to write signatures (for inputs and outputs together) for this function when there are multiple inputs and multiple outputs. I know it is recommended to let numba infer them. If possible, I would prefer a signature that can handle both types without changing the code.

Ali_Sh

1 Answer

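Since a single number can be represented as an array with one element, a simple solution is to take a one-element slice instead of a scalar, so that every assignment to pnt_group_ids_ has type int64[::1]:

pnt_group_ids_ = sorted_fb_i[ind:ind+1]

With that change, the function can be given an explicit signature. The return types go inside Tuple((...)), followed by the argument types in parentheses:

import numba as nb

@nb.njit("Tuple((int64[::1],ListType(int64[::1]),ListType(int64[::1])))(int64[::1], ListType(int64[::1]), ListType(int64[::1]), int64)")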
def points_group_ids(sorted_fb_i, fb_groups_ids, moved_fb_group_ids, ind):
    pnt_group_ids_ = sorted_fb_i[ind:ind+1]
    for i in range(len(fb_groups_ids)):
        if sorted_fb_i[ind] in fb_groups_ids[i]:
            pnt_group_ids_ = fb_groups_ids[i]
            moved_fb_group_ids.append(fb_groups_ids.pop(i))
            break
    return pnt_group_ids_, fb_groups_ids, moved_fb_group_ids

This compiles and runs, although without more context on what the function is for I cannot say whether it is the best approach.

Ahmed AEK
  • @Ali_Sh That's not possible; numba cannot generate LLVM code for it. Think about how C++ and Rust work: both can compile through LLVM, and a function has exactly one return type, because the code is statically typed, not dynamically typed like Python. There can only be one return type, and it must be known beforehand. – Ahmed AEK Oct 12 '22 at 19:58
  • @Ali_Sh I have edited the answer to add the signature. You can have multiple return types for multiple input types, i.e. if the input is a number then the output is also a number, and LLVM will compile two versions of the function for it, but deducing the return type inside the function at run time is not possible. – Ahmed AEK Oct 12 '22 at 20:21