
How can I use the MXNet metrics API to calculate the accuracy of a multiclass logistic regression classifier when the labels are one-hot vectors? Here is an example of the labels:

Class1: [1,0,0,0]
Class2: [0,1,0,0]
Class3: [0,0,1,0]
Class4: [0,0,0,1]

The naive way of using this metric produces a wrong result, because argmax squashes the model output into the index of the class with the maximum probability, while the labels stay one-hot vectors:

import mxnet as mx
from mxnet import nd

def evaluate_accuracy(data_iterator, ctx, net):
    acc = mx.metric.Accuracy()
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        out = net(data)
        p = nd.argmax(out, axis=1)
        acc.update(preds=p, labels=label)  # label is still one-hot here, so the comparison is wrong
    return acc.get()[1]

My current solution is a little hacky:

def evaluate_accuracy(data_iterator, ctx, net):
    acc = mx.metric.Accuracy()
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        out = net(data)
        p = nd.argmax(out, axis=1)
        l = nd.argmax(label, axis=1)  # squash one-hot labels into class indices
        acc.update(preds=p, labels=l)
    return acc.get()[1]
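The comparison behind that workaround can be sketched with plain NumPy, without needing a trained net or data iterator (the model outputs below are made-up values purely for illustration):

```python
import numpy as np

# One-hot ground-truth labels for four samples (classes 0..3)
labels = np.array([
    [1, 0, 0, 0],  # class 0
    [0, 1, 0, 0],  # class 1
    [0, 0, 1, 0],  # class 2
    [0, 0, 0, 1],  # class 3
])

# Hypothetical per-class model outputs for the same four samples
out = np.array([
    [0.7, 0.1, 0.1, 0.1],  # predicts class 0 (correct)
    [0.2, 0.5, 0.2, 0.1],  # predicts class 1 (correct)
    [0.6, 0.1, 0.2, 0.1],  # predicts class 0 (wrong, true class is 2)
    [0.1, 0.1, 0.1, 0.7],  # predicts class 3 (correct)
])

preds = out.argmax(axis=1)         # [0, 1, 0, 3]
true = labels.argmax(axis=1)       # one-hot squashed to indices: [0, 1, 2, 3]
accuracy = (preds == true).mean()  # 3 out of 4 correct -> 0.75
print(accuracy)
```

Comparing `preds` against the raw one-hot `labels` array instead of `true` is exactly the mistake in the naive version: the shapes no longer describe the same thing.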
Dimon Buzz

1 Answer


The Accuracy metric is tricky: it doesn't work with one-hot-encoded labels as the ground truth.

I find this somewhat counter-intuitive, but you need to pass actual class indices as the ground truth rather than one-hot-encoded labels (e.g., 2 instead of [0,0,1,0]). Otherwise, accuracy will not work the way you expect. Take a look at my previous answer here - Why MXNet is reporting the incorrect validation accuracy?

Also, MXNet expects classes to start at 0. So, if your classes start from 1, you need to adjust them all by subtracting 1.

Sergei
  • Thanks Sergei. I guess this should be a bug then, since SoftmaxCrossEntropyLoss allows for labels having probability distributions, but metrics API doesn't account for this. – Dimon Buzz Mar 12 '18 at 19:47
  • @Sergey, btw you can answer this too https://stackoverflow.com/questions/49217210/why-normalizing-labels-in-mxnet-makes-accuracy-close-to-100 it looks like 0 based labels is a reason there – Dimon Buzz Mar 12 '18 at 20:45