Let's say I have an array of shape batch x max_len x output_size, where batch, max_len, and output_size are all positive integers. I have a list of indices, one per batch element, each of which refers to an individual item along dimension 1 (i.e. max_len). How can I select from the array given these indices?
As a concrete example, let's say I have the following:
>>> import numpy as np
>>> l = np.random.randn(4,5,6)
>>> l.shape
(4, 5, 6)
>>> idx = [0,0,2,3]
When I index l with idx, I get:
>>> l[:,idx,:].shape
(4, 4, 6)
>>>
I also tried np.take, but got the same result:
>>> np.take(l,idx,axis=1).shape
(4, 4, 6)
>>>
However, the output I am looking for has shape (4,1,6), as I want to pick only one item for each element in the batch (i.e. the first dimension). How can I produce output with the proper shape?
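To make the desired behaviour explicit, here is a loop-based sketch of what I am after. It assumes that idx[i] is the position to pick along dimension 1 for batch element i; the name expected is just for illustration.

```python
import numpy as np

l = np.random.randn(4, 5, 6)
idx = [0, 0, 2, 3]  # assumption: idx[i] is the item to keep for batch element i

# For each batch element i, keep only row idx[i] along axis 1.
# Indexing with the one-element list [j] keeps that axis, so each piece has shape (1, 6).
expected = np.stack([l[i, [j], :] for i, j in enumerate(idx)])

print(expected.shape)  # (4, 1, 6)
```

Ideally I would like to get the same result without the Python loop.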