I'm trying to obtain the positions of all the elements of a sub-list taken from a big list.
In Python, using NumPy, say I have:
from datetime import datetime as dt
from time import perf_counter

import numpy as np
from numba import jit, int64
# Problem size: n query values are looked up in a vocabulary of N integers.
n, N = 20, 120000
# Explicit int64 so that the arrays match the int64[:] signature of the jitted
# `check` on every platform (the default integer dtype can be 32-bit, e.g. on
# Windows with NumPy 1.x).
int_vocabulary = np.arange(N, dtype=np.int64)
np.random.shuffle(int_vocabulary)  # to make the problem non-trivial
# Draw n distinct values (replace=False) from the vocabulary as the queries.
int_sequence = np.random.choice(int_vocabulary, n, replace=False)
and I want to get the positions that the integers in int_sequence
have in the big sequence int_vocabulary. I'm interested in fast computation.
So far I've tried a numba brute-force search, a numpy mask approach, a brute-force list comprehension (as a baseline), and a mix of list comprehension and numpy masks.
@jit(int64[:](int64[:], int64[:], int64, int64))
def check(int_sequence, int_vocabulary, n, N):
    """Brute-force lookup of each query value's position in the vocabulary.

    Args:
        int_sequence: 1-D int64 array holding the values to look up.
        int_vocabulary: 1-D int64 array that is searched.
        n: number of leading entries of ``int_sequence`` to process.
        N: number of leading entries of ``int_vocabulary`` to scan.

    Returns:
        int64 array of length ``n``. Entry ``k`` is the index at which
        ``int_sequence[k]`` occurs in ``int_vocabulary`` (the highest such
        index if the value occurs more than once), or ``N`` if it does not
        occur at all.
    """
    all_indices = np.full(n, N)  # N doubles as the "not found" sentinel
    for xi in range(n):
        target = int_sequence[xi]
        # Scan from the end and stop at the first hit: this gives the same
        # index a full forward scan ends up with (the last occurrence)
        # without visiting the remainder of the vocabulary.
        for i in range(N - 1, -1, -1):
            if int_vocabulary[i] == target:
                all_indices[xi] = i
                break
    return all_indices
# Time 10 calls of the numba brute-force search. perf_counter() is a
# monotonic, high-resolution clock intended for measuring short durations,
# unlike the wall clock returned by datetime.now().
# NOTE: `check` is declared with an explicit signature, which per numba's
# documentation triggers compilation at decoration time, so JIT compilation
# should not be part of this measurement.
t0 = perf_counter()
for _ in range(10):
    all_indices0 = check(int_sequence, int_vocabulary, n, N)
t0 = perf_counter() - t0  # elapsed seconds for the 10 runs
print("numba : ", t0)
# Time 10 runs of the NumPy mask approach: OR together one equality mask per
# query value, then read off the positions of the True entries.
# perf_counter() is monotonic and high-resolution, unlike datetime.now().
t0 = perf_counter()
for _ in range(10):
    mask = np.zeros(len(int_vocabulary), dtype=bool)
    for x in int_sequence:
        mask |= int_vocabulary == x  # accumulate hits for this query value
    all_indices1 = np.flatnonzero(mask)  # indices in ascending order
t0 = perf_counter() - t0  # elapsed seconds for the 10 runs
print("numpy :", t0)
# Time 10 runs of the pure-Python baseline: walk the whole vocabulary and keep
# the index of every element that also appears in int_sequence.
# perf_counter() is monotonic and high-resolution, unlike datetime.now().
t0 = perf_counter()
for _ in range(10):
    all_indices2 = np.array([i for i, x in enumerate(int_vocabulary)
                             if x in int_sequence])
t0 = perf_counter() - t0  # elapsed seconds for the 10 runs
print("list comprehension : ", t0)
# Time 10 runs of the mixed approach: build one equality mask per query value
# in a list comprehension, stack them, and sum along axis 0 so that every
# vocabulary position matched by some query value becomes non-zero.
# perf_counter() is monotonic and high-resolution, unlike datetime.now().
t0 = perf_counter()
for _ in range(10):
    mask = np.sum(np.array([int_vocabulary == x for x in int_sequence]), axis=0)
    all_indices3 = np.flatnonzero(mask)  # indices in ascending order
t0 = perf_counter() - t0  # elapsed seconds for the 10 runs
print("mixed numpy + list comprehension : ", t0)
# Sanity check that all four approaches found the same positions. Comparing
# the arrays element-wise is stricter than comparing their sums, which can
# coincide for different index sets. The numba result is ordered like
# int_sequence, so it is sorted first; the other three are already in
# ascending index order.
assert np.array_equal(np.sort(all_indices0), all_indices1)
assert np.array_equal(all_indices1, all_indices2)
assert np.array_equal(all_indices2, all_indices3)
Each calculation is repeated 10 times to get comparable statistics. The outcome is:
numba : 0.028039
numpy : 0.011616
list comprehension : 3.116753
mixed numpy + list comprehension : 0.032301
Nevertheless, I wonder whether there is a faster algorithm for this problem.