I am currently building a project based on a TensorFlow CNN trained on the MNIST dataset, with a server interface.
In the prediction part, I use tf.argmax() to get the index of the largest logit, which should be the predicted value. However, the value it returns doesn't seem to be the correct answer.
The predict function looks roughly like this:
self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
self._create_model()
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('../checkpoints/')
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
pred = tf.nn.softmax(self.logits)
prediction = tf.argmax(pred, 1)
logit = sess.run(pred)
result = sess.run(prediction)[0]
print(logit)
print(result)
return result
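A side note on the last lines of the function: softmax is strictly increasing in each logit, so it never changes which index is largest, and tf.argmax(self.logits, 1) should pick the same class as tf.argmax(pred, 1). The minimal NumPy sketch below illustrates that property; np.argmax(..., axis=1) stands in for tf.argmax(pred, 1), and the logits are made-up values, not output from the model above.

```python
import numpy as np

def softmax(logits):
    # Subtract the row max for numerical stability; the result is unchanged.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Hypothetical logits for a batch of one image (shape [1, 10]).
logits = np.array([[0.1, -2.0, 0.3, 1.2, -0.5, 7.4, 0.0, 2.2, -1.1, 0.9]])
probs = softmax(logits)

# Both calls reduce over axis 1, like tf.argmax(pred, 1).
pred_from_probs = np.argmax(probs, axis=1)
pred_from_logits = np.argmax(logits, axis=1)
print(pred_from_probs, pred_from_logits)  # [5] [5]
```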
And the output is:
127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
1
As you can see, the softmax output shows that index 5 holds the maximum value, but tf.argmax() returned 1 instead.
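Applying argmax to the printed array does give 5, so argmax agrees with the values that were printed. One detail worth checking: in TensorFlow 1.x each sess.run call evaluates the fetched tensors again, so logit and result above come from two separate forward passes. If anything in the graph is random at prediction time (for example, a dropout layer that is not disabled; this is an assumption, since the model code is only linked), the two passes can disagree, whereas fetching both at once with logit, result = sess.run([pred, prediction]) takes them from the same pass. The NumPy sketch below simulates this with a hypothetical noisy_forward function standing in for such a model.

```python
import numpy as np

# The softmax output printed in the question: argmax over axis 1 is index 5.
printed = np.array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])
print(np.argmax(printed, axis=1))  # [5]

def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def noisy_forward(rng, keep_prob=0.5):
    # Hypothetical stand-in for a CNN whose dropout is still active at
    # prediction time: every call draws a fresh random dropout mask.
    features = np.ones(10)
    mask = rng.random(10) < keep_prob
    features = features * mask / keep_prob
    # Small made-up per-class offsets so argmax never hits an exact tie.
    return (features + 0.01 * np.arange(10))[None, :]

rng = np.random.default_rng(0)
trials = 100

# Like two separate sess.run calls: printed probabilities and predicted
# class come from different forward passes.
separate_mismatches = 0
for _ in range(trials):
    printed_probs = softmax(noisy_forward(rng))
    result = np.argmax(softmax(noisy_forward(rng)), axis=1)[0]
    separate_mismatches += int(np.argmax(printed_probs, axis=1)[0] != result)

# Like a single sess.run([pred, prediction]) call: both come from one pass.
same_pass_mismatches = 0
for _ in range(trials):
    probs = softmax(noisy_forward(rng))
    result = np.argmax(probs, axis=1)[0]
    same_pass_mismatches += int(np.argmax(probs, axis=1)[0] != result)

print(separate_mismatches, same_pass_mismatches)
```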
By the way, my model is the basic MNIST CNN model from the link.
So is something wrong with tf.argmax(), or is there a problem in my code?