I want to know how to use tf.argmax on a 3D array.

My input data looks like this:

[[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]], 
[[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]]

For this input, I want argmax to return:

[[2, 4, 0], [2, 3, 1]]

I also want to use the softmax_cross_entropy_with_logits function with data in this format.

How should I use tf.nn.softmax_cross_entropy_with_logits, tf.equal(tf.argmax(...)) and tf.reduce_mean(tf.cast(...)) here?

Giuseppe Marra
Gi Yeon Shin

1 Answer

You can use tf.argmax along axis=2 (the last axis, so axis=-1 works too):

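import tensorflow as tf

a = tf.constant([[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]],
                 [[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]])
# a has shape (2, 3, 5); taking argmax over axis=2 leaves shape (2, 3)
b = tf.argmax(a, axis=2)  # [[2, 4, 0], [2, 3, 1]]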
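For the rest of the question (loss and accuracy on 3D logits), here is a minimal sketch. It assumes TensorFlow 2.x with eager execution and treats the input above as logits of shape (batch, steps, classes). The integer `labels` tensor is hypothetical example data, one-hot encoded so it matches the logits' shape. Because `tf.nn.softmax_cross_entropy_with_logits` and `tf.argmax` both work on the last axis here, the 3D case needs no reshaping.

```python
import tensorflow as tf

# Logits from the question, shape (batch=2, steps=3, classes=5).
logits = tf.constant(
    [[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]],
     [[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]],
    dtype=tf.float32)

# Hypothetical integer class labels, shape (2, 3); replace with your own.
labels = tf.constant([[2, 4, 1], [2, 3, 1]], dtype=tf.int64)
one_hot_labels = tf.one_hot(labels, depth=5)  # float32, shape (2, 3, 5)

# The softmax is taken over the last axis (axis=-1 by default), so a 3D
# input yields one loss value per (batch, step) pair: shape (2, 3).
per_step_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=one_hot_labels, logits=logits)
loss = tf.reduce_mean(per_step_loss)  # scalar mean over all positions

# Accuracy: compare predicted and true class indices along the last axis.
predictions = tf.argmax(logits, axis=-1)                 # int64, shape (2, 3)
targets = tf.argmax(one_hot_labels, axis=-1)             # int64, shape (2, 3)
correct = tf.equal(predictions, targets)                 # bool, shape (2, 3)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))  # fraction correct

print(predictions.numpy())  # [[2 4 0]
                            #  [2 3 1]]
print(accuracy.numpy())     # 5 of 6 positions match, about 0.833
```

If your labels are already integer class indices, `tf.nn.sparse_softmax_cross_entropy_with_logits` accepts them directly, which skips the one-hot step.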
Ishant Mrinal