Suppose I have the following DataFrame Q_df:
(0, 0) (0, 1) (0, 2) (1, 0) (1, 1) (1, 2) (2, 0) (2, 1) (2, 2)
(0, 0) 0.000 0.00 0.0 0.64 0.000 0.0 0.512 0.000 0.0
(0, 1) 0.000 0.00 0.8 0.00 0.512 0.0 0.000 0.512 0.0
(0, 2) 0.000 0.64 0.0 0.00 0.000 0.8 0.000 0.000 1.0
(1, 0) 0.512 0.00 0.0 0.00 0.000 0.8 0.512 0.000 0.0
(1, 1) 0.000 0.64 0.0 0.00 0.000 0.0 0.000 0.512 0.0
(1, 2) 0.000 0.00 0.8 0.64 0.000 0.0 0.000 0.000 1.0
(2, 0) 0.512 0.00 0.0 0.64 0.000 0.0 0.000 0.512 0.0
(2, 1) 0.000 0.64 0.0 0.00 0.512 0.0 0.512 0.000 0.0
(2, 2) 0.000 0.00 0.8 0.00 0.000 0.8 0.000 0.000 0.0
which is generated using the following code:
import itertools
import numpy as np
import pandas as pd
states = list(itertools.product(range(3), repeat=2))
Q = np.array([[0.000,0.000,0.000,0.640,0.000,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.512,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.000,0.800,0.000,0.000,1.000],
[0.512,0.000,0.000,0.000,0.000,0.800,0.512,0.000,0.000],
[0.000,0.640,0.000,0.000,0.000,0.000,0.000,0.512,0.000],
[0.000,0.000,0.800,0.640,0.000,0.000,0.000,0.000,1.000],
[0.512,0.000,0.000,0.640,0.000,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.512,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.000,0.800,0.000,0.000,0.000]])
Q_df = pd.DataFrame(index=states, columns=states, data=Q)
For each row of Q, I would like to get the column name corresponding to the maximum value in the row. If I try
policy = Q_df.idxmax()
then the resulting Series looks like this:
(0, 0) (1, 0)
(0, 1) (0, 2)
(0, 2) (0, 1)
(1, 0) (0, 0)
(1, 1) (0, 1)
(1, 2) (0, 2)
(2, 0) (0, 0)
(2, 1) (0, 1)
(2, 2) (0, 2)
The first row looks OK: the maximum element of the first row is 0.64 and occurs in column (1, 0). The second row is also correct. For the third row, however, the maximum element is 1.0 and occurs in column (2, 2), so I would expect the corresponding value in policy to be (2, 2), not (0, 1).
Any idea what is going wrong here?
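One way to narrow this down is to check which axis the call reduces over. The sketch below rebuilds Q_df from the code above and compares idxmax() with no arguments (pandas' default is axis=0, which gives one result per column) against idxmax(axis=1), which gives one result per row. The names column_policy and row_policy are only illustrative, and converting each result to a plain dict with to_dict() is just a convenience so entries can be looked up by tuple key.

```python
import itertools

import numpy as np
import pandas as pd

# Rebuild Q_df exactly as in the question.
states = list(itertools.product(range(3), repeat=2))
Q = np.array([[0.000, 0.000, 0.000, 0.640, 0.000, 0.000, 0.512, 0.000, 0.000],
              [0.000, 0.000, 0.800, 0.000, 0.512, 0.000, 0.000, 0.512, 0.000],
              [0.000, 0.640, 0.000, 0.000, 0.000, 0.800, 0.000, 0.000, 1.000],
              [0.512, 0.000, 0.000, 0.000, 0.000, 0.800, 0.512, 0.000, 0.000],
              [0.000, 0.640, 0.000, 0.000, 0.000, 0.000, 0.000, 0.512, 0.000],
              [0.000, 0.000, 0.800, 0.640, 0.000, 0.000, 0.000, 0.000, 1.000],
              [0.512, 0.000, 0.000, 0.640, 0.000, 0.000, 0.000, 0.512, 0.000],
              [0.000, 0.640, 0.000, 0.000, 0.512, 0.000, 0.512, 0.000, 0.000],
              [0.000, 0.000, 0.800, 0.000, 0.000, 0.800, 0.000, 0.000, 0.000]])
Q_df = pd.DataFrame(index=states, columns=states, data=Q)

# Default (axis=0): for each COLUMN, the row label where that column peaks.
column_policy = Q_df.idxmax().to_dict()

# axis=1: for each ROW, the column label where that row peaks.
# Ties (e.g. row (2, 2)) resolve to the first occurrence.
row_policy = Q_df.idxmax(axis=1).to_dict()

print(column_policy[(0, 2)])  # (0, 1): the surprising value from the question
print(row_policy[(0, 2)])     # (2, 2): the column holding the 1.0 in row (0, 2)
```

Comparing the two results entry by entry against the printed Series in the question shows which one matches it.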