Question: How to index numpy array with given indexes?

Description

In reinforcement learning, I have many discrete distributions corresponding to different states, like the following:

import numpy as np
distributions = np.array([[0.1,0.2,0.7],[0.3,0.3,0.4],[0.2,0.2,0.6]])

# array([[0.1, 0.2, 0.7],  # \pi(s0)
#        [0.3, 0.3, 0.4],  # \pi(s1)
#        [0.2, 0.2, 0.6]]) # \pi(s2)

Then, I want to get the probabilities of taking action 0 in state s0, taking action 2 in state s1, and taking action 1 in state s2 respectively.

So I stored the index values in an array like the following:

actions = np.array([[0],[2],[1]])

# array([[0],  # taking action 0 in state s0
#        [2],  # taking action 2 in state s1
#        [1]]) # taking action 1 in state s2

What I expected to get.

I want to index distributions using actions, and expect to get a result like:

# array([0.1,0.4,0.2])
# or 
# array([[0.1],
#        [0.4],
#        [0.2]])

What I tried.

I tried np.take(distributions, actions), but the returned array([0.1, 0.7, 0.2]) was obviously not what I wanted. I also tried distributions[:,actions], which gave me another wrong answer, shown below:

array([[0.1, 0.7, 0.2],
       [0.3, 0.4, 0.3],
       [0.2, 0.6, 0.2]])         

Question

What can I do to solve this problem?

Jack Huang

1 Answer

In [614]: distributions = np.array([[0.1,0.2,0.7],[0.3,0.3,0.4],[0.2,0.2,0.6]])
In [615]: actions = np.array([[0],[2],[1]])

Use a [0,1,2] row index:

In [616]: distributions[np.arange(3), actions]                                  
Out[616]: 
array([[0.1, 0.3, 0.2],
       [0.7, 0.4, 0.6],
       [0.2, 0.3, 0.2]])

Oops: actions has shape (3,1), which broadcasts with (3,) to produce a (3,3) selection. Instead we want to use a (3,)-shaped actions:

In [617]: distributions[np.arange(3), actions.ravel()]                          
Out[617]: array([0.1, 0.4, 0.2])
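The shape mismatch can be checked before indexing by asking NumPy what the two index arrays broadcast to. This is a minimal sketch using the same arrays as above; the names rows and picked are introduced here for illustration, and np.arange(distributions.shape[0]) is assumed as a stand-in for the hard-coded np.arange(3) so that it works for any number of states:

```python
import numpy as np

distributions = np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4], [0.2, 0.2, 0.6]])
actions = np.array([[0], [2], [1]])

# One row index per state (hypothetical name, not from the original answer).
rows = np.arange(distributions.shape[0])

# (3,) against (3, 1) broadcasts to (3, 3): every row paired with every action.
print(np.broadcast(rows, actions).shape)          # (3, 3)

# (3,) against (3,) stays (3,): row i is paired only with action i.
print(np.broadcast(rows, actions.ravel()).shape)  # (3,)

picked = distributions[rows, actions.ravel()]
print(picked)  # [0.1 0.4 0.2]
```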

or to get a (3,1) result:

In [619]: distributions[[[0],[1],[2]], actions]                                 
Out[619]: 
array([[0.1],
       [0.4],
       [0.2]])
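As an alternative that is not part of the answer above, np.take_along_axis (available in NumPy 1.15 and later) can take the (3,1) actions array as it is, provided it has the same number of dimensions as distributions. A small sketch under that assumption; column and flat are illustrative names:

```python
import numpy as np

distributions = np.array([[0.1, 0.2, 0.7], [0.3, 0.3, 0.4], [0.2, 0.2, 0.6]])
actions = np.array([[0], [2], [1]])

# For each row i, pick the column given by actions[i, 0].
# The result keeps the (3, 1) shape of the index array.
column = np.take_along_axis(distributions, actions, axis=1)
print(column.shape)  # (3, 1)

# Flatten if a 1-D result is preferred.
flat = column.ravel()
print(flat)  # [0.1 0.4 0.2]
```

This avoids building the row index by hand, at the cost of requiring actions to stay 2-D.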
hpaulj
    This is a perfect answer! So, the code is: ```distributions[[[0],[1],[2]],[[0],[2],[1]]]``` or ```distributions[[0,1,2],[0,2,1]]```. – jiadong Nov 15 '19 at 07:58