I have three questions:
- Is there any way to extract the maximum value and its index from a 2D tensor? Suppose we have a tensor x = tf.constant([[0, 2, 1], [0, 0, 8], [2, 9, 0]]); the desired output would be max = 9, index = [2, 1].
I've tried tf.argmax, but its axis argument doesn't accept a tuple. For example, it doesn't work with axis = (0, 1), only with axis = 0 or axis = 1.
I have referred to this question, but it does not answer my question above.
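One possible approach, sketched here assuming TensorFlow 2.x with eager execution (in 1.x graph mode the tensors would have to be evaluated in a session): flatten the tensor so a single tf.argmax covers every element, then convert the flat position back to [row, column] with tf.unravel_index.

```python
import tensorflow as tf

# Example tensor from the question.
x = tf.constant([[0, 2, 1], [0, 0, 8], [2, 9, 0]])

# Flatten to 1-D so one argmax ranges over all elements.
flat = tf.reshape(x, [-1])

# Position of the maximum in the flattened tensor
# (int32 so it matches the dtype returned by tf.shape).
flat_idx = tf.argmax(flat, axis=0, output_type=tf.int32)

# The maximum value itself.
max_value = tf.reduce_max(x)

# Map the flat position back to [row, column] in the original shape.
idx = tf.unravel_index(flat_idx, tf.shape(x))

print(max_value.numpy(), idx.numpy())  # 9 [2 1]
```

If several elements share the maximum, tf.argmax returns the first one in row-major order.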
- Also, how can I find several of the largest values (say, the top 5) of a 2D tensor, since tf.reduce_max only returns one?
- And once I have the indices, how can I use them to index the values of another tensor y that has the same shape as x above?
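For the top-k part, a minimal sketch under the same TensorFlow 2.x eager assumption: tf.math.top_k only ranks along the last axis, so the tensor is flattened first, and the returned flat indices are turned into [row, column] pairs with tf.unravel_index. The value k = 5 is just the example from the question.

```python
import tensorflow as tf

x = tf.constant([[0, 2, 1], [0, 0, 8], [2, 9, 0]])
k = 5  # how many of the largest values to keep

# top_k works along the last axis, so flatten to rank over all elements.
flat = tf.reshape(x, [-1])
values, flat_indices = tf.math.top_k(flat, k=k)

# unravel_index returns shape [2, k] (rows, then columns);
# transpose to get one [row, column] pair per value, shape [k, 2].
indices_2d = tf.transpose(tf.unravel_index(flat_indices, tf.shape(x)))

print(values.numpy())      # the k largest values, in descending order
print(indices_2d.numpy())  # the matching [row, column] positions
```

Keeping flat_indices around is handy, because they can be reused directly to pick values out of any other tensor with the same shape.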
Update 1: As Aldream pointed out, my first two questions are duplicates of How to find the top k values in a 2-D tensor in tensorflow. The third question still remains.
Update 2: For the third question, if anyone has the same problem, we can use tf.gather.
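A sketch of how the tf.gather approach could look, again assuming TensorFlow 2.x eager mode; the contents of y here are made up for illustration. Because tf.gather selects along a single axis, the flattened y is indexed with the flat indices from top_k. tf.gather_nd, which accepts [row, column] pairs on the unflattened y, is shown as an alternative that gives the same result.

```python
import tensorflow as tf

x = tf.constant([[0, 2, 1], [0, 0, 8], [2, 9, 0]])
# Hypothetical second tensor with the same shape as x.
y = tf.constant([[10, 20, 30], [40, 50, 60], [70, 80, 90]])

# Flat positions of the 5 largest values of x.
_, flat_indices = tf.math.top_k(tf.reshape(x, [-1]), k=5)

# Option 1: tf.gather on the flattened y with the flat indices.
from_gather = tf.gather(tf.reshape(y, [-1]), flat_indices)

# Option 2: tf.gather_nd on y itself with [row, column] pairs.
indices_2d = tf.transpose(tf.unravel_index(flat_indices, tf.shape(x)))
from_gather_nd = tf.gather_nd(y, indices_2d)

print(from_gather.numpy())
print(from_gather_nd.numpy())
```

Both calls pick the elements of y at the positions where x has its largest values, so the choice mostly depends on whether flat or [row, column] indices are already at hand.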