
I want to modify NumPy array elements inside a parallel loop, using Numba's `njit` as follows.

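import numpy as np
from numba import njit, prange

@njit  # must itself be jitted to be callable from myfunction
def get_indice():
    ....
    return new_indice

@njit()
def myfunction(arr):
    for i in prange(100):
        indices = get_indice()
        arr[indices] += 1  # several iterations may update the same element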

But modifying the array elements this way seems unsafe because of a race condition. So I am thinking of building a list of the indices first and modifying `arr` afterwards, something like:

@njit()
def myfunction(arr):
    index_list = []
    for i in prange(100):
        indices = get_indice()
        index_list.append(indices)  # concurrent appends from several threads
    arr[index_list] += 1

But `list.append()` is not thread-safe either, as stated here. The order of the list does not matter to me. Is there any way to work around this problem?

Xudong

1 Answer


You can simply store the indices in a preallocated array using direct indexing, so that each iteration writes only to its own slot. This is faster than appending to a list and is thread-safe in this case (assuming `get_indice` is also thread-safe). Note that, AFAIK, `arr[index_list] += 1` assumes that an index cannot appear twice (and if indices indeed never repeat, then the first solution has no race condition and should actually work). A loop can fix that and might actually be faster. Here is an example of fixed code:

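import numpy as np
from numba import njit, prange

@njit(parallel=True)
def myfunction(arr):
    # One slot per iteration (assumes get_indice() returns one integer):
    # iteration i writes only index_list[i], so threads never collide.
    index_list = np.empty(100, dtype=np.int64)
    for i in prange(100):
        index_list[i] = get_indice()
    # Serial pass: every occurrence of a repeated index is counted.
    for i in range(100):
        arr[index_list[i]] += 1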

Note that `parallel=True` is needed for `prange` to actually be useful (i.e., to use multiple threads). Also note that multi-threading introduces a significant overhead, so using `prange` can actually be slower unless `get_indice()` is fairly expensive.
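A self-contained version of this two-pass approach can be sketched as follows. The `get_indice(i)` below is hypothetical: unlike the question's version it takes the loop index and is a pure function, so it is trivially thread-safe. The `try`/`except` fallback is also an assumption, not part of the answer; it only lets the sketch run serially (without JIT) when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumed fallback: plain Python decorators and a serial range.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda f: f
    prange = range

N_ITER = 100  # number of parallel iterations (as in the answer)

@njit
def get_indice(i):
    # Hypothetical stand-in: depends only on i, so it is thread-safe.
    # np.int64 keeps the arithmetic in signed integers.
    return (np.int64(i) * 7) % 10

@njit(parallel=True)
def myfunction(arr):
    index_list = np.empty(N_ITER, dtype=np.int64)
    for i in prange(N_ITER):
        # Each iteration writes only its own slot: no race condition.
        index_list[i] = get_indice(i)
    for i in range(N_ITER):
        # Serial pass: repeated indices are all counted.
        arr[index_list[i]] += 1
    return arr

result = myfunction(np.zeros(10, dtype=np.int64))
print(result)
```

With this toy `get_indice`, each of the 10 bins receives 10 of the 100 indices, so every element of `result` ends up as 10 even though indices repeat.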

Jérôme Richard
  • Could you explain why the first solution would work if I get repeated indices? I thought the race condition would overwrite `arr` at some point. – Xudong Jun 19 '22 at 21:32
  • There is a race condition only if the indices returned by `get_indice()` are not a subset of a permutation (i.e., not all of them are unique). So, no, it will not work with repeated indices. I just pointed out that `arr[index_list] += 1` implicitly assumes that there are no repeated elements (so, depending on the actual input data, there were possibly two errors or no error at all). – Jérôme Richard Jun 19 '22 at 21:48
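The point about `arr[index_list] += 1` and repeated indices can be checked in plain NumPy (outside of `njit`): fancy-index increment is buffered, so a repeated index is applied only once, whereas `np.add.at` and `np.bincount` count every occurrence. This is a NumPy-only sketch; it does not assume that either function is available inside Numba-compiled code.

```python
import numpy as np

idx = np.array([1, 1, 3])  # index 1 appears twice

# Buffered fancy-index increment: behaves like arr[idx] = arr[idx] + 1,
# so the duplicate index 1 is incremented only once.
fancy = np.zeros(5, dtype=np.int64)
fancy[idx] += 1
print(fancy)  # [0 1 0 1 0]

# np.add.at is unbuffered and counts every occurrence.
accum = np.zeros(5, dtype=np.int64)
np.add.at(accum, idx, 1)
print(accum)  # [0 2 0 1 0]

# np.bincount returns the same per-index counts as a new array.
counts = np.bincount(idx, minlength=5)
print(counts)  # [0 2 0 1 0]
```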