tf.argmax returns the index of the maximum value along the specified axis. That axis is collapsed: the result has the same shape as the input, except that the specified axis disappears, and each entry holds the index of the maximum along that axis. I'll make the examples with tf.reduce_max so we can follow the values.
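To make the shape rule concrete, here is a minimal sketch. It uses NumPy as a stand-in (an assumption on my part; np.argmax follows the same axis convention as tf.argmax), and the all-zeros input is hypothetical since only its shape matters:

```python
import numpy as np

# Hypothetical input: only the shape (3, 3, 2) matters here.
x = np.zeros((3, 3, 2))

# Reducing along an axis removes that axis from the result's shape.
shapes = {axis: np.argmax(x, axis=axis).shape for axis in range(x.ndim)}
print(shapes)  # {0: (3, 2), 1: (3, 2), 2: (3, 3)}
```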
Let's start with your array:
import numpy as np
import tensorflow as tf

x = np.array([[[15, 23],
               [3, 1],
               [80, 56]],
              [[98, 95],
               [97, 82],
               [10, 37]],
              [[65, 32],
               [25, 39],
               [54, 68]]])
First, tf.reduce_max(x, axis=0). The carets mark the largest value at each position across the three blocks:
([[[15, 23],
   [ 3,  1],
   [80, 56]],
     ^
  [[98, 95],
     ^   ^
   [97, 82],
     ^   ^
   [10, 37]],
  [[65, 32],
   [25, 39],
   [54, 68]]])
         ^
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[98, 95],
       [97, 82],
       [80, 68]])>
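The matching argmax tells us which block each of those maxima came from. A small sketch, again with NumPy as a stand-in (an assumption; the TensorFlow call would be tf.argmax(x, axis=0)):

```python
import numpy as np

x = np.array([[[15, 23], [3, 1], [80, 56]],
              [[98, 95], [97, 82], [10, 37]],
              [[65, 32], [25, 39], [54, 68]]])

# Index of the block (0, 1 or 2) holding the maximum at each position.
idx0 = np.argmax(x, axis=0)
print(idx0)
# [[1 1]
#  [1 1]
#  [0 2]]
```

For instance, 80 comes from block 0 and 68 from block 2, while the other four maxima all sit in block 1.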
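Now tf.reduce_max(x, axis=1), with the three blocks laid side by side. The carets mark the largest value in each column of each block:
([[[15, 23],   [[98, 95],   [[65, 32],
                  ^   ^        ^
   [ 3,  1],    [97, 82],    [25, 39],
   [80, 56]],   [10, 37]],   [54, 68]]])
     ^   ^                         ^
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[80, 56],
       [98, 95],
       [65, 68]])>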
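Along axis 1 the index is the row inside each block. A sketch under the same assumption (NumPy standing in for tf.argmax(x, axis=1)):

```python
import numpy as np

x = np.array([[[15, 23], [3, 1], [80, 56]],
              [[98, 95], [97, 82], [10, 37]],
              [[65, 32], [25, 39], [54, 68]]])

# For each block and each column, the row index (0, 1 or 2) of the maximum.
idx1 = np.argmax(x, axis=1)
print(idx1)
# [[2 2]
#  [0 0]
#  [0 2]]
```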
Now tf.reduce_max(x, axis=2). The carets mark the largest value in each row:
([[[15, 23],
         ^
   [ 3,  1],
     ^
   [80, 56]],
     ^
  [[98, 95],
     ^
   [97, 82],
     ^
   [10, 37]],
         ^
  [[65, 32],
     ^
   [25, 39],
         ^
   [54, 68]]])
         ^
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[23,  3, 80],
       [98, 97, 37],
       [65, 39, 68]])>
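To tie this back to argmax: along axis 2 the index is the position (0 or 1) inside each row, and gathering the values at those indices gives back the reduce_max result. A sketch with NumPy as a stand-in (an assumption; np.take_along_axis is the NumPy gather used here, not a TensorFlow call):

```python
import numpy as np

x = np.array([[[15, 23], [3, 1], [80, 56]],
              [[98, 95], [97, 82], [10, 37]],
              [[65, 32], [25, 39], [54, 68]]])

# Position (0 or 1) of the maximum inside each row.
idx2 = np.argmax(x, axis=2)
print(idx2)
# [[1 0 0]
#  [0 0 1]
#  [0 1 1]]

# Picking the values at those positions recovers the maxima along axis 2.
recovered = np.take_along_axis(x, idx2[..., np.newaxis], axis=2).squeeze(axis=2)
print(recovered)
# [[23  3 80]
#  [98 97 37]
#  [65 39 68]]
```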