
I would like to get the indices of the maximum values along the last axis.

For example, given:

[[[0.1, 0.3, 0.6],
  [0.0, 0.4, 0.1]],
 [[0.9, 0.2, 0.6],
  [0.8, 0.1, 0.5]]]

I would like to get [[0,0,2], [0,1,1], [1,0,0], [1,1,0]]. What is the easiest way to do that in TensorFlow?

Ynjxsjmh
  • Does this answer your question? [TensorFlow: Max of a tensor along an axis](https://stackoverflow.com/questions/34987509/tensorflow-max-of-a-tensor-along-an-axis) – Tom McLean Oct 23 '22 at 15:58

2 Answers


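You can take advantage of TF's broadcasting along the last dimension: take the max of each row with keepdims=True, compare it with the original tensor, and pass the resulting boolean mask to tf.where to get the coordinates of the maxima.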

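import tensorflow as tf

a = tf.constant([[[0.1, 0.3, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]])
b = tf.reduce_max(a, -1, keepdims=True)  # max of each row, shape (2, 2, 1)
tf.where(a == b)  # coordinates of every element equal to its row's max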

Output

<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 2],
       [0, 1, 1],
       [1, 0, 0],
       [1, 1, 0]], dtype=int64)>
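To see why the comparison works, it may help to check the shapes involved. A minimal sketch of the same idea with the intermediate results named (`mask` and `idx` are just illustrative names):

```python
import tensorflow as tf

a = tf.constant([[[0.1, 0.3, 0.6], [0.0, 0.4, 0.1]],
                 [[0.9, 0.2, 0.6], [0.8, 0.1, 0.5]]])

# keepdims=True keeps a trailing axis of size 1, so b has shape (2, 2, 1)
b = tf.reduce_max(a, axis=-1, keepdims=True)

# Broadcasting stretches b along the last axis to (2, 2, 3),
# so each element is compared with the max of its own row
mask = a == b

# With a single argument, tf.where returns the coordinates of the True entries
idx = tf.where(mask)

print(b.shape)     # (2, 2, 1)
print(mask.shape)  # (2, 2, 3)
print(idx.numpy().tolist())
```

Without keepdims=True, b would have shape (2, 2), which does not broadcast against (2, 2, 3), so the comparison would fail.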

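If a row has several maximum values and you only want to keep the index of the first one, you can derive which row (segment) each coordinate in the result belongs to, then use segment_min to keep the smallest index in each segment. This works because tf.where returns coordinates in row-major order, so the segment ids are already sorted, as segment_min requires.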

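import tensorflow as tf

# Row [0, 0] now has two maxima (0.6 at indices 1 and 2)
a = tf.constant([[[0.1, 0.6, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]])
b = tf.reduce_max(a, -1, keepdims=True)
c = tf.cast(tf.where(a == b), tf.int32)  # all (i, j, k) coordinates of maxima
# Flatten the leading (i, j) coordinates into one row id: i * a.shape[1] + j
d = tf.reduce_sum(tf.math.cumprod(a.shape[:-1], reverse=True, exclusive=True) * c[:,:-1], axis=1)
tf.math.segment_min(c,d)  # smallest coordinate per row, i.e. the first max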

Output

<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[0, 0, 1],
       [0, 1, 1],
       [1, 0, 0],
       [1, 1, 0]])>
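A different way to keep only the first maximum per row, swapping in a cumulative-sum mask instead of segment_min (this is not part of the answer above, just a sketch): mark the maxima, count them cumulatively along the last axis, and keep only the max position where the running count equals 1.

```python
import tensorflow as tf

# Row [0, 0] has two maxima (0.6 at indices 1 and 2)
a = tf.constant([[[0.1, 0.6, 0.6], [0.0, 0.4, 0.1]],
                 [[0.9, 0.2, 0.6], [0.8, 0.1, 0.5]]])

is_max = a == tf.reduce_max(a, axis=-1, keepdims=True)

# Running count of maxima seen so far along the last axis
running = tf.cumsum(tf.cast(is_max, tf.int32), axis=-1)

# The first maximum is the only max position where the running count is 1
first_max = tf.logical_and(is_max, tf.equal(running, 1))

idx = tf.where(first_max)
print(idx.numpy().tolist())
```

This avoids computing segment ids, at the cost of an extra cumsum over the whole tensor.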
bui
  • Nice, thanks! If the tensor was a = tf.constant([[[0.1, 0.6, 0.6],[0.0, 0.4, 0.1]],[[0.9, 0.2, 0.6],[0.8, 0.1, 0.5]]]) instead (so there are two maximum values in the first row), the result contains both indices for that row. Is it possible to get only one value per row with this solution in a nice way? – estrella1995 Oct 24 '22 at 06:32
  • I updated my answer to address your question. Unfortunately, TF does not implement numpy's `ravel_multi_index`, so I had to do it manually to derive `d`. – bui Oct 24 '22 at 08:05
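The manual computation of `d` in the answer above amounts to a row-major `ravel_multi_index` over the leading axes. A small helper along those lines might look like this; `ravel_leading_indices` is a hypothetical name, not a TensorFlow function:

```python
import tensorflow as tf


def ravel_leading_indices(coords, shape):
    """Hypothetical helper: map multi-dimensional indices to flat row-major ids.

    coords: int32 tensor of shape (N, k), one k-dimensional index per row.
    shape:  the k sizes of the indexed axes, e.g. a.shape[:-1].
    """
    shape = tf.constant(list(shape), dtype=tf.int32)
    # Row-major strides: product of the sizes of all later axes
    strides = tf.math.cumprod(shape, reverse=True, exclusive=True)
    return tf.reduce_sum(coords * strides, axis=1)


coords = tf.constant([[0, 0], [0, 0], [0, 1], [1, 0], [1, 1]], dtype=tf.int32)
ids = ravel_leading_indices(coords, (2, 2))
print(ids.numpy().tolist())
```

With the answer's variables, `ravel_leading_indices(c[:, :-1], a.shape[:-1])` should give the same values as `d`.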
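import numpy as np
import tensorflow as tf

# tf.math.argmax gives the index of each row's max, but not in the format you want
# (`a` is the tensor from the question)
max_index = tf.reshape(tf.math.argmax(a, -1), (-1, 1))
max_index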

<tf.Tensor: shape=(4, 1), dtype=int64, numpy=
array([[2],
       [1],
       [0],
       [0]])>

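# Format the output: build the (i, j) coordinates of every row of the leading two axes
# (int64 to match argmax; reshape to 2 columns, one per leading axis)
idx_axis = tf.reshape(tf.constant(np.indices((a.shape[0], a.shape[1]), dtype=np.int64).transpose(1, 2, 0)), (-1, 2))
idx_axis
<tf.Tensor: shape=(4, 2), dtype=int64, numpy=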
array([[0, 0],
       [0, 1],
       [1, 0],
       [1, 1]])>

tf.concat([idx_axis,max_index], axis=1)
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 2],
       [0, 1, 1],
       [1, 0, 0],
       [1, 1, 0]])>
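The argmax approach can be generalized to any number of leading axes without NumPy by flattening them and using tf.unravel_index to recover each row's coordinates. A sketch, assuming `a` has a fully known static shape; note that the tf.math.argmax documentation does not guarantee which index is returned when there are ties, so this is only suitable when each row has a single max.

```python
import tensorflow as tf

a = tf.constant([[[0.1, 0.3, 0.6], [0.0, 0.4, 0.1]],
                 [[0.9, 0.2, 0.6], [0.8, 0.1, 0.5]]])

lead_shape = a.shape[:-1]             # sizes of all axes except the last
num_rows = lead_shape.num_elements()  # number of rows along the last axis

# Index of the max within each row, as a column of shape (num_rows, 1)
max_index = tf.reshape(tf.math.argmax(a, axis=-1), (-1, 1))

# Coordinates of each row in the leading axes, shape (num_rows, rank - 1)
row_ids = tf.range(num_rows, dtype=tf.int64)
dims = tf.constant(lead_shape.as_list(), dtype=tf.int64)
idx_axis = tf.transpose(tf.unravel_index(row_ids, dims))

result = tf.concat([idx_axis, max_index], axis=1)
print(result.numpy().tolist())
```

Both tf.math.argmax and the reshaped row ids follow row-major order, so the rows of `idx_axis` and `max_index` line up.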
Vijay Mariappan