
I've created a neural network in TensorFlow. The network is multilabel: it tries to predict multiple output labels (in this case three) for one input set. Currently I use this code to test how accurately my network predicts the three labels:

```python
_, indices_1 = tf.nn.top_k(prediction, 3)
_, indices_2 = tf.nn.top_k(item_data, 3)
correct = tf.equal(indices_1, indices_2)
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
percentage = accuracy.eval({champion_data:input_data, item_data:output_data})
```

That code works fine. The problem is that I'm now trying to write code that tests whether the 3 true labels (indices_2) are amongst the top 5 predicted items (indices_1). I know TensorFlow has an in_top_k() method, but as far as I know it doesn't accept multilabel targets. Currently I've been trying to compare them using a for loop:

```python
_, indices_1 = tf.nn.top_k(prediction, 5)
_, indices_2 = tf.nn.top_k(item_data, 3)
indices_1 = tf.unpack(tf.transpose(indices_1, (1, 0)))
indices_2 = tf.unpack(tf.transpose(indices_2, (1, 0)))
correct = []
for element in indices_1:
    for element_2 in indices_2:
        if element == element_2:
            correct.append(True)
        else:
            correct.append(False)
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
percentage = accuracy.eval({champion_data:input_data, item_data:output_data})
```

However, that doesn't work. The code runs but my accuracy is always 0.0.

So I have one of two questions:

1) Is there an easy replacement for in_top_k() that supports multilabel classification, so I don't have to write custom code?

2) If not 1: what am I doing wrong that gives me an accuracy of 0.0?

Hasse Iona

1 Answer


When you do

```python
correct = tf.equal(indices_1, indices_2)
```

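you are checking not just whether the two index tensors contain the same elements, but whether they contain the same elements in the same positions. Because tf.nn.top_k returns indices ordered by score, a prediction with exactly the right three labels in a different order is counted as wrong. This doesn't sound like what you want.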
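To see why position matters, here is a minimal plain-Python sketch of a single example. The scores are hypothetical, and `top_k_indices` is a stand-in written for illustration (not a TensorFlow function) that mimics the indices output of `tf.nn.top_k`: largest score first, ties keeping the lower index first.

```python
def top_k_indices(scores, k):
    """Indices of the k largest scores, largest first.

    Hypothetical stand-in for tf.nn.top_k(...).indices on one row;
    Python's sort is stable, so tied scores keep the lower index first.
    """
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


# Hypothetical single example: 5 classes, 3 of them are true labels.
prediction = [0.1, 0.7, 0.9, 0.05, 0.8]
item_data = [0.0, 1.0, 1.0, 0.0, 1.0]

indices_1 = top_k_indices(prediction, 3)  # [2, 4, 1]
indices_2 = top_k_indices(item_data, 3)   # [1, 2, 4]

# What tf.equal does: compare rank by rank.
positional = [a == b for a, b in zip(indices_1, indices_2)]
# What is actually wanted: is each true label anywhere in the prediction?
membership = [b in indices_1 for b in indices_2]

print(sum(positional) / len(positional))  # 0.0
print(sum(membership) / len(membership))  # 1.0
```

The network got all three labels right, yet the position-by-position check scores it 0.0.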

The setdiff1d op will tell you which indices are in indices_1 but not in indices_2, which you can then use to count errors.
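A minimal sketch of counting errors that way, in plain Python so the logic is easy to check. `setdiff1d` below is a hypothetical stand-in that mimics the first output (`out`) of TF 1.x `tf.setdiff1d(x, y)`: the values of x that are not in y, in x's order. Two assumptions worth stating: `tf.setdiff1d` works on 1-D tensors, so it has to be applied per example; and for the top-5 case the true labels should be the first argument (`setdiff1d(labels, predictions)`), otherwise the two extra predictions are always counted as errors. In TensorFlow, the per-example error count would then be `tf.size` of that first output.

```python
def setdiff1d(x, y):
    """Values of x that are not in y, in x's order (duplicates kept).

    Hypothetical stand-in for the `out` result of TF 1.x tf.setdiff1d(x, y).
    """
    y_set = set(y)
    return [v for v in x if v not in y_set]


# Hypothetical batch of two examples: top-5 predicted and top-3 true indices.
pred_top5 = [[2, 4, 1, 0, 3], [7, 5, 6, 9, 8]]
true_top3 = [[1, 2, 4], [5, 6, 0]]

# For each example, the true labels that are missing from the top-5 predictions.
misses = [setdiff1d(t, p) for t, p in zip(true_top3, pred_top5)]

errors = sum(len(m) for m in misses)    # number of missed labels
total = sum(len(t) for t in true_top3)  # number of labels checked
accuracy = 1.0 - errors / total
print(accuracy)  # about 0.833
```

Here the second example misses label 0, so 1 of the 6 labels is wrong.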

I think this overly strict correctness check may be what is causing the wrong result.
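The for-loop version has a separate problem as well. As far as I know, in TensorFlow 1.x graph mode `==` between two `tf.Tensor` objects is plain Python identity comparison, not an element-wise op (TF 2.x changed this), so `element == element_2` is evaluated once, while the graph is being built, and is False for any two distinct tensors. The sketch below mimics that with a hypothetical `SymbolicTensor` class that has no value-based `__eq__`; it illustrates the mechanism and is not TensorFlow code. A graph-side fix would express the comparison with TF ops instead, for example `tf.equal` on `tf.expand_dims`-ed index tensors followed by `tf.reduce_any`.

```python
class SymbolicTensor:
    """Hypothetical stand-in for a TF 1.x graph tensor.

    It holds no value and defines no value-based __eq__, so `==`
    falls back to Python's identity comparison.
    """

    def __init__(self, name):
        self.name = name


# Distinct graph tensors, like the unpacked rows of indices_1 and indices_2.
indices_1 = [SymbolicTensor("pred_rank_%d" % i) for i in range(5)]
indices_2 = [SymbolicTensor("label_rank_%d" % i) for i in range(3)]

# The question's loop: it runs in Python while the graph is built,
# so it only ever compares tensor objects, never their values.
correct = []
for element in indices_1:
    for element_2 in indices_2:
        correct.append(element == element_2)  # identity test: always False here

accuracy = sum(correct) / len(correct)
print(accuracy)  # 0.0
```

Even with a value-based comparison, averaging over all 5 × 3 rank pairs would cap the result at 3/15 = 0.2 per example, so the average should be taken over the 3 true labels instead.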

Alexandre Passos
  • Thanks so much! This is a big step in the right direction. It did require me to update my TensorFlow, though, as my version did not have setdiff1d yet. Would you mind elaborating on how I can count the errors? I've tried a few things, but I can't figure out how to tell how many differences setdiff1d has found. – Hasse Iona Dec 05 '16 at 13:38