
Suppose I have the following DataFrame Q_df:

        (0, 0)  (0, 1)  (0, 2)  (1, 0)  (1, 1)  (1, 2)  (2, 0)  (2, 1)  (2, 2)
(0, 0)   0.000    0.00     0.0    0.64   0.000     0.0   0.512   0.000     0.0
(0, 1)   0.000    0.00     0.8    0.00   0.512     0.0   0.000   0.512     0.0
(0, 2)   0.000    0.64     0.0    0.00   0.000     0.8   0.000   0.000     1.0
(1, 0)   0.512    0.00     0.0    0.00   0.000     0.8   0.512   0.000     0.0
(1, 1)   0.000    0.64     0.0    0.00   0.000     0.0   0.000   0.512     0.0
(1, 2)   0.000    0.00     0.8    0.64   0.000     0.0   0.000   0.000     1.0
(2, 0)   0.512    0.00     0.0    0.64   0.000     0.0   0.000   0.512     0.0
(2, 1)   0.000    0.64     0.0    0.00   0.512     0.0   0.512   0.000     0.0
(2, 2)   0.000    0.00     0.8    0.00   0.000     0.8   0.000   0.000     0.0

which is generated using the following code:

import itertools

import numpy as np
import pandas as pd

states = list(itertools.product(range(3), repeat=2))

Q = np.array([[0.000,0.000,0.000,0.640,0.000,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.512,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.000,0.800,0.000,0.000,1.000],
[0.512,0.000,0.000,0.000,0.000,0.800,0.512,0.000,0.000],
[0.000,0.640,0.000,0.000,0.000,0.000,0.000,0.512,0.000],
[0.000,0.000,0.800,0.640,0.000,0.000,0.000,0.000,1.000],
[0.512,0.000,0.000,0.640,0.000,0.000,0.000,0.512,0.000],
[0.000,0.640,0.000,0.000,0.512,0.000,0.512,0.000,0.000],
[0.000,0.000,0.800,0.000,0.000,0.800,0.000,0.000,0.000]])

Q_df = pd.DataFrame(index=states, columns=states, data=Q)

For each row of Q, I would like to get the column name corresponding to the maximum value in the row. If I try

policy = Q_df.idxmax()

then the resulting Series looks like this:

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (0, 1)
(1, 0)    (0, 0)
(1, 1)    (0, 1)
(1, 2)    (0, 2)
(2, 0)    (0, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)

The first two entries look OK: the maximum element of the first row is 0.64 and occurs in column (1,0), and the maximum of the second row is 0.8 in column (0,2). For the third row, however, the maximum element is 1.0 and occurs in column (2,2), so I would expect the corresponding value in policy to be (2,2), not (0,1).

Any idea what is going wrong here?

smci
Kurt Peek

1 Answer


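IIUC, you want the maximum along each row. By default, idxmax uses axis=0, so it returns the row label of the maximum in each column; pass axis=1 to get the column label of the maximum in each row: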

policy = Q_df.idxmax(axis=1)

(0, 0)    (1, 0)
(0, 1)    (0, 2)
(0, 2)    (2, 2)
(1, 0)    (1, 2)
(1, 1)    (0, 1)
(1, 2)    (2, 2)
(2, 0)    (1, 0)
(2, 1)    (0, 1)
(2, 2)    (0, 2)
dtype: object
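
To make the difference between the two axes easier to see, here is a minimal sketch on a small toy frame. The data and the names df, per_column, per_row, mask and tied are made up for illustration and are not part of the question. It also shows that idxmax reports only the first label when a row has ties (as row (2, 2) of Q_df does), and one possible way to list every tied column.

```python
import pandas as pd

# Hypothetical toy data (not the Q_df from the question).
df = pd.DataFrame(
    [[1, 9, 3],
     [7, 2, 8],
     [4, 5, 5]],
    index=["r0", "r1", "r2"],
    columns=["a", "b", "c"],
)

# axis=0 (the default): for each column, the row label of its maximum.
per_column = df.idxmax()

# axis=1: for each row, the column label of its maximum.
# On ties (row r2 has 5 in both b and c) only the first label is returned.
per_row = df.idxmax(axis=1)

# One way to recover all tied columns: compare each row with its own maximum.
mask = df.eq(df.max(axis=1), axis=0)
tied = {row: list(mask.columns[mask.loc[row].to_numpy()]) for row in mask.index}

print(per_column)
print(per_row)
print(tied)
```

Here per_column is indexed by the column labels and per_row by the row labels, which is why the default call in the question appeared to give plausible but wrong answers for a square frame with identical row and column labels.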
Nickil Maveli