
I have a 3D array and a 2D array of indices. How can I use the indices to select elements along the last axis?

import numpy as np

# example array
shape = (4,3,2)
x = np.random.uniform(0,1, shape)

# indices
idx = np.random.randint(0,shape[-1], shape[:-1])

Here is a loop that gives the desired result, but there should be an efficient, vectorized way to do this.

result = np.zeros(shape[:-1])
for i in range(shape[0]):
    for j in range(shape[1]):
        result[i,j] = x[i,j,idx[i,j]]

guyguyguy12345
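To check any vectorized candidate, it helps to have the loop as a function. The sketch below generalizes the question's double loop to any number of leading axes with `np.ndindex`; the name `select_last_axis_loop` is hypothetical, and the data uses NumPy's `default_rng` instead of the legacy `np.random` functions.

```python
import numpy as np

def select_last_axis_loop(x, idx):
    """Loop reference (hypothetical helper): result[pos] = x[pos + (idx[pos],)]."""
    result = np.empty(idx.shape, dtype=x.dtype)
    # np.ndindex yields every index tuple over the leading axes: (0, 0), (0, 1), ...
    for pos in np.ndindex(idx.shape):
        # Append the chosen last-axis index to the leading-axis position
        result[pos] = x[pos + (idx[pos],)]
    return result

rng = np.random.default_rng(0)
shape = (4, 3, 2)
x = rng.uniform(0, 1, shape)
idx = rng.integers(0, shape[-1], shape[:-1])
result = select_last_axis_loop(x, idx)
print(result.shape)  # (4, 3)
```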

2 Answers


Corrected for a 2D index array: first use meshgrid to build Cartesian index grids for the first two axes, then pass them together with idx.

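# Index grids for the first two axes, each of shape (4, 3)
m = np.meshgrid(range(shape[0]), range(shape[1]), indexing="ij")
# Advanced indexing: results[i, j] = x[m[0][i, j], m[1][i, j], idx[i, j]]
results = x[m[0], m[1], idx]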

Daraan
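A possible memory-saving variant of the meshgrid approach, sketched under the assumption that only the indexing result matters: `meshgrid(..., sparse=True)` returns open grids of shapes (4, 1) and (1, 3) instead of two dense (4, 3) arrays, and advanced indexing broadcasts them against `idx`. The hand-written `[:, None]` form builds the same open grids.

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (4, 3, 2)
x = rng.uniform(0, 1, shape)
idx = rng.integers(0, shape[-1], shape[:-1])

# Open grids: i has shape (4, 1), j has shape (1, 3); both broadcast with idx (4, 3)
i, j = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij", sparse=True)
result = x[i, j, idx]

# The same open grids written by hand with None (np.newaxis)
result_manual = x[np.arange(shape[0])[:, None], np.arange(shape[1])[None, :], idx]
print(i.shape, j.shape, result.shape)  # (4, 1) (1, 3) (4, 3)
```

Because the grids broadcast rather than being materialized, the extra memory grows with the sum of the leading dimensions instead of their product.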

A possible solution:

np.take_along_axis(x, np.expand_dims(idx, axis=-1), axis=-1).squeeze(axis=-1)
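The `take_along_axis` call works because that function requires an index array with the same number of dimensions as `x`; adding a trailing length-1 axis satisfies this, and removing it afterwards restores the shape of `idx`. A minimal sketch wrapping this in a helper (the name `select_last_axis` is hypothetical) shows it also handles more than two leading axes:

```python
import numpy as np

def select_last_axis(x, idx):
    # idx[..., None] is equivalent to np.expand_dims(idx, axis=-1);
    # [..., 0] drops the length-1 axis again, like .squeeze(axis=-1)
    return np.take_along_axis(x, idx[..., None], axis=-1)[..., 0]

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, (5, 4, 3, 2))  # three leading axes this time
idx = rng.integers(0, x.shape[-1], x.shape[:-1])
out = select_last_axis(x, idx)
print(out.shape)  # (5, 4, 3)
```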

Alternatively, build index grids with meshgrid and use advanced indexing:

i, j = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
x[i, j, idx]

PaulS
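The meshgrid alternative hardcodes two leading axes. A sketch of one way to generalize it, assuming NumPy 1.17 or later (where `np.indices` gained the `sparse` argument): `np.indices(idx.shape, sparse=True)` returns a tuple of open grids, one per leading axis, and appending `idx` to that tuple indexes the last axis. The helper name `select_last_axis_fancy` is hypothetical.

```python
import numpy as np

def select_last_axis_fancy(x, idx):
    # One open grid per leading axis, e.g. shapes (5, 1, 1), (1, 4, 1), (1, 1, 3)
    grids = np.indices(idx.shape, sparse=True)
    # Leading grids plus idx form a full advanced index into x
    return x[grids + (idx,)]

rng = np.random.default_rng(0)
x = rng.uniform(0, 1, (5, 4, 3, 2))
idx = rng.integers(0, x.shape[-1], x.shape[:-1])
out = select_last_axis_fancy(x, idx)
print(out.shape)  # (5, 4, 3)
```

This should agree with the `take_along_axis` result; which one reads better is largely a matter of taste.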