
I am trying to train Google's sketch recognition model (the one in the linked GitHub repository), and I have run into a problem that has been bothering me for a long time.

The problem is as follows: I used the code in the link and the quickdraw data to complete the training, so I now have a trained model consisting of three files (.meta, .index, .data). I want to calculate the confusion matrix of the trained model over its 345 categories. But since I have never used TensorFlow's Estimator API, I don't know how to load my trained model files into the code and test it (no training), and how to get the classification score after the softmax layer (used to calculate the confusion matrix).

The Estimator API has confused me for a long time. Please help me solve this with the code from the link:

def create_estimator_and_specs(run_config):
    """Creates an Experiment configuration based on the estimator and input fn."""
    model_params = tf.contrib.training.HParams(
        num_layers=FLAGS.num_layers,
        num_nodes=FLAGS.num_nodes,
        batch_size=FLAGS.batch_size,
        num_conv=ast.literal_eval(FLAGS.num_conv),
        conv_len=ast.literal_eval(FLAGS.conv_len),
        num_classes=get_num_classes(),
        learning_rate=FLAGS.learning_rate,
        gradient_clipping_norm=FLAGS.gradient_clipping_norm,
        cell_type=FLAGS.cell_type,
        batch_norm=FLAGS.batch_norm,
        dropout=FLAGS.dropout)
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=model_params)
    train_spec = tf.estimator.TrainSpec(
        input_fn=get_input_fn(
            mode=tf.estimator.ModeKeys.TRAIN,
            tfrecord_pattern=FLAGS.training_data,
            batch_size=FLAGS.batch_size),
        max_steps=FLAGS.steps)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=get_input_fn(
            mode=tf.estimator.ModeKeys.EVAL,
            tfrecord_pattern=FLAGS.eval_data,
            batch_size=FLAGS.batch_size)
        )
    return estimator, train_spec, eval_spec

def main(unused_args):
    estimator, train_spec, eval_spec = create_estimator_and_specs(
        run_config=tf.estimator.RunConfig(
            model_dir=FLAGS.model_dir,
            save_checkpoints_secs=300,
            save_summary_steps=100)
        )
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

I want to load my trained model into the code above and calculate the confusion matrix for the 345 categories.

Shuo Yang

2 Answers

1

You can use the library function tf.confusion_matrix:

tf.confusion_matrix(
    labels,
    predictions,
    num_classes=None,
    dtype=tf.int32,
    name=None,
    weights=None
)

Computes the confusion matrix from predictions and labels.

tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
      [[0 0 0 0 0]
       [0 0 1 0 0]
       [0 0 1 0 0]
       [0 0 0 0 0]
       [0 0 0 0 1]]
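
For intuition, the counting rule behind that output can be sketched in plain Python. This is only an illustration, not TensorFlow's implementation, and the function name confusion_matrix_counts is made up for this example. Rows stand for the true labels and columns for the predicted labels; when num_classes is not given, the size is one more than the largest id seen.

```python
def confusion_matrix_counts(labels, predictions, num_classes=None):
    """Count (true label, predicted label) pairs into a square matrix."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if num_classes is None:
        # One more than the largest class id in either list.
        num_classes = max(max(labels), max(predictions)) + 1
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_id, predicted_id in zip(labels, predictions):
        # Row index: true class. Column index: predicted class.
        matrix[true_id][predicted_id] += 1
    return matrix


# Same inputs as the tf.confusion_matrix example above.
for row in confusion_matrix_counts([1, 2, 4], [2, 2, 4]):
    print(row)
```

For the 345-category model, passing num_classes=345 keeps the matrix shape fixed even if some categories never occur in the test data.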

In your case, the following code might help:

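# estimator.predict yields one dict per example; the "predictions" key (class id)
# is an assumption about what model_fn returns, so adjust it to your model.
labels = list(test_set.target)
predictions = [p["predictions"] for p in estimator.predict(input_fn=test_input_fn)]
confusion_matrix = tf.confusion_matrix(labels, predictions, num_classes=345)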
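
Regarding the softmax scores: estimator.predict yields one output per example, in whatever form model_fn puts into the predictions of its EstimatorSpec, and an Estimator built with the same model_dir restores the latest checkpoint from that directory when predict is called, so no training step is involved. The sketch below is a rough illustration under one assumption that you should check against your model_fn: each output is a dict with a "logits" entry holding the raw scores. The helper names and the FakeEstimator class are hypothetical and exist only so that the example runs without TensorFlow.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_classes_and_scores(estimator, input_fn, logits_key="logits"):
    """Return (class_ids, scores) for every example that input_fn provides."""
    class_ids, scores = [], []
    for output in estimator.predict(input_fn=input_fn):
        # logits_key is an assumption about the model_fn's prediction dict.
        probs = softmax([float(x) for x in output[logits_key]])
        scores.append(probs)
        class_ids.append(probs.index(max(probs)))  # argmax over the classes
    return class_ids, scores


class FakeEstimator:
    """Stand-in for a real tf.estimator.Estimator (illustration only)."""

    def predict(self, input_fn):
        yield {"logits": [0.0, 2.0, 1.0]}
        yield {"logits": [3.0, 0.0, 0.0]}


class_ids, scores = predicted_classes_and_scores(FakeEstimator(), input_fn=None)
print(class_ids)  # [1, 0]
```

With the real model, the estimator returned by create_estimator_and_specs would take the place of FakeEstimator, and the resulting class_ids can be paired with the true labels to build the confusion matrix.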
Arjun Kava
  • Thank you for the examples provided, especially the in-context one (+1). For the sake of completeness, please post a modified version of ["Update 2" from this question](https://stackoverflow.com/questions/49774035/how-to-classify-a-quickdraw-doodle-using-tensorflows-sketch-rnn-tutorial) that computes and prints the confusion_matrix. – George Profenza Feb 09 '19 at 10:26
1

"I don't know how to load my trained model files into the code and test it"

Use Datasets for Estimators

The tf.data module contains a collection of classes that allows you to easily load data, manipulate it, and pipe it into your model.

  • Reading in-memory data from numpy arrays.
  • Reading lines from a csv file.

"how to get the classification score after the softmax layer (used to calculate Confusion matrix)"

Use tf.keras, a high-level API to build and train models in TensorFlow

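from tensorflow import keras

# test_dataset is a placeholder, not a real Keras module; keras.datasets.mnist
# stands in here because it offers the same load_data() interface.
test_dataset = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = test_dataset.load_data()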
piet.t
Tamara Koliada