This question generalises a previously asked question: Indexing a list of list of vectors with a list of indices
Given some data represented as an array of shape (N1, ..., Nk, L, H) and a batch of indices of shape (N1, ..., Nk), how can I index the data so that the output has shape (N1, ..., Nk, H)? Semantically, I'd like to replace the for loops below with a single NumPy call:
import numpy as np

N1, N2 = (3, 3, 3), ()  # batch shapes: three batch axes, and no batch axes at all
L, H = 5, 2
data = np.arange(np.prod(N1) * L * H).reshape(*N1, L, H)
inds = np.arange(np.prod(N1), dtype=int).reshape(N1) % L
out = np.empty((*N1, H), dtype=data.dtype)
for ii in np.ndindex(N1):
    out[ii] = data[ii + (inds[ii],)]  # pick position inds[ii] along the L axis
assert out.shape == (3, 3, 3, 2)
data = np.arange(np.prod(N2) * L * H).reshape(*N2, L, H)
inds = np.arange(np.prod(N2), dtype=int).reshape(N2) % L
out = np.empty((*N2, H), dtype=data.dtype)
for ii in np.ndindex(N2):
    out[ii] = data[ii + (inds[ii],)]
assert out.shape == (2,)
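Written elementwise, both loops compute the following for every batch index (n_1, ..., n_k) and every h in 0, ..., H-1, where each entry of inds must lie in 0, ..., L-1. When k = 0 (the N2 = () case) the batch index is empty and this reduces to out[h] = data[inds, h].

```latex
\mathrm{out}[n_1, \dots, n_k, h] = \mathrm{data}\bigl[n_1, \dots, n_k,\ \mathrm{inds}[n_1, \dots, n_k],\ h\bigr]
```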
It seems like flattening the batch dimensions could work, but are there utilities in NumPy capable of this kind of indexing?
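A minimal sketch of three loop-free candidates, each checked against the loop from the question. It assumes NumPy 1.17 or later (`np.take_along_axis` was added in 1.15 and `sparse=True` for `np.indices` in 1.17); the helper names `gather_loop`, `gather_flat`, `gather_take_along` and `gather_indices` are hypothetical. `gather_flat` is the flattening idea; `gather_take_along` relies on `np.take_along_axis` requiring the index array to have the same number of dimensions as `data` while only needing to broadcast in the non-gathered axes; `gather_indices` builds open index grids over the batch axes and uses plain advanced indexing.

```python
import numpy as np


def gather_loop(data, inds):
    """Reference: the explicit loop from the question."""
    out = np.empty((*inds.shape, data.shape[-1]), dtype=data.dtype)
    for ii in np.ndindex(inds.shape):
        out[ii] = data[ii + (inds[ii],)]
    return out


def gather_flat(data, inds):
    """Flatten the batch axes, index row by row, then restore the batch shape."""
    L, H = data.shape[-2:]
    flat = data.reshape(-1, L, H)        # (N1*...*Nk, L, H)
    rows = np.arange(flat.shape[0])      # one row per batch element
    return flat[rows, inds.ravel()].reshape(*inds.shape, H)


def gather_take_along(data, inds):
    """Gather along the L axis with take_along_axis, then drop the size-1 axis."""
    # inds[..., None, None] has shape (N1, ..., Nk, 1, 1); its last axis
    # broadcasts against H, so the result has shape (N1, ..., Nk, 1, H).
    picked = np.take_along_axis(data, inds[..., None, None], axis=-2)
    return picked[..., 0, :]


def gather_indices(data, inds):
    """Advanced indexing with open (sparse) grids over the batch axes."""
    # np.indices(..., sparse=True) returns one broadcastable arange per batch
    # axis (an empty tuple when there are no batch axes); together with inds
    # they select one L position per batch element, leaving the H axis intact.
    grids = np.indices(inds.shape, sparse=True)
    return data[(*grids, inds)]
```

All three also cover the k = 0 case, where `inds` is a 0-d array and the result has shape (H,). `gather_indices` returns the (N1, ..., Nk, H) result directly, whereas `gather_take_along` goes through an intermediate length-1 axis.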