
Let's say I have an array of shape batch x max_len x output_size, where batch, max_len and output_size are positive integers. I also have a list of indices, one per batch element, each pointing to an item along dimension 1 (i.e. max_len). How can I select from the array using these indices?

As a concrete example, let's say I have the following:

>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]

When I index l with idx, I get:

>>> l[:,idx,:].shape
(4, 4, 6)

I also tried np.take, but got the same result:

>>> np.take(l,idx,axis=1).shape
(4, 4, 6)

However, the output I am looking for has shape (4,1,6): I want to pick exactly one item along dimension 1 for each element of the batch (i.e. the first dimension). How can I produce an output with that shape?
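To pin down what the (4,1,6) result should contain, the selection can be written as a plain Python loop. This is only a reference sketch to check vectorized answers against, not an efficient method; the names `rng` and `out` and the use of `np.random.default_rng` are illustrative choices, not part of the question.

```python
import numpy as np

# Same shapes as in the question; the random values are illustrative only.
rng = np.random.default_rng(0)
l = rng.standard_normal((4, 5, 6))
idx = [0, 0, 2, 3]  # one index along axis 1 per batch element

# Reference definition of the desired result: for each batch element b,
# keep only row idx[b] along axis 1, preserving a length-1 axis there.
out = np.empty((l.shape[0], 1, l.shape[2]), dtype=l.dtype)
for b, i in enumerate(idx):
    out[b, 0, :] = l[b, i, :]

print(out.shape)  # (4, 1, 6)
```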

Clement Attlee

1 Answer


Use np.take_along_axis after extending idx to have the same number of dimensions as l:

np.take_along_axis(l,np.asarray(idx)[:,None,None],axis=1)
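A runnable sketch of the take_along_axis approach, assuming NumPy 1.15 or later (where np.take_along_axis was added) and one index per batch element. The intermediate name `idx_3d` is illustrative. The key point is that the index array must have the same ndim as l; its size-1 axes then broadcast against l, so axis 1 of the output has length 1.

```python
import numpy as np

rng = np.random.default_rng(0)
l = rng.standard_normal((4, 5, 6))
idx = [0, 0, 2, 3]

# Reshape idx from (4,) to (4, 1, 1) so it has the same ndim as l;
# the size-1 last axis broadcasts against l's last axis (size 6).
idx_3d = np.asarray(idx)[:, None, None]
picked = np.take_along_axis(l, idx_3d, axis=1)

print(picked.shape)  # (4, 1, 6)

# Each batch row b holds l[b, idx[b], :].
for b, i in enumerate(idx):
    assert np.array_equal(picked[b, 0], l[b, i])
```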

Alternatively, with explicit integer-array indexing:

l[np.arange(len(idx)),idx][:,None] # skip [:,None] for a (4,6) shaped output

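A sketch of the integer-array indexing variant, assuming len(idx) equals the batch size so that np.arange(len(idx)) enumerates every batch element. The names `rows` and `rows_3d` are illustrative. Two index arrays of equal length are paired elementwise, which is what distinguishes this from l[:,idx,:], where the slice on axis 0 combines every batch element with every entry of idx.

```python
import numpy as np

rng = np.random.default_rng(0)
l = rng.standard_normal((4, 5, 6))
idx = [0, 0, 2, 3]

# Pairing np.arange(4) with idx selects the (b, idx[b]) pairs elementwise;
# the trailing axis is kept whole, giving shape (4, 6).
rows = l[np.arange(len(idx)), idx]
print(rows.shape)  # (4, 6)

# Re-insert a length-1 axis at position 1 to get (4, 1, 6).
rows_3d = rows[:, None]
print(rows_3d.shape)  # (4, 1, 6)
```

Both approaches return the same values; take_along_axis generalizes more directly to other axes, while the fancy-indexing form makes the (4,6) result available without an extra reshape.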
Divakar