How can I use the MXNet metrics API to calculate the accuracy of a multiclass logistic regression classifier whose labels are one-hot vectors? Here are example labels:
Class1: [1,0,0,0]
Class2: [0,1,0,0]
Class3: [0,0,1,0]
Class4: [0,0,0,1]
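Each label vector contains exactly one 1, so it carries the same information as a class index. The NumPy sketch below (an illustration only, not MXNet code) shows the conversion in both directions: `argmax` along axis 1 recovers the index, and indexing an identity matrix rebuilds the one-hot rows.

```python
import numpy as np

# One-hot label vectors for the four classes, one row per class.
one_hot = np.array([
    [1, 0, 0, 0],  # Class1
    [0, 1, 0, 0],  # Class2
    [0, 0, 1, 0],  # Class3
    [0, 0, 0, 1],  # Class4
])

# Each row has a single 1, so argmax along axis 1 gives the class index.
indices = np.argmax(one_hot, axis=1)
print(indices)  # [0 1 2 3]

# Indexing the identity matrix with class indices rebuilds the one-hot rows.
restored = np.eye(4, dtype=int)[indices]
print(np.array_equal(restored, one_hot))  # True
```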
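The naive way to use this function produces a wrong result: `argmax` squashes each model output into the index of its highest-probability class, while the labels are still one-hot vectors, so predictions and labels are no longer comparable:

```python
import mxnet as mx
from mxnet import nd

def evaluate_accuracy(data_iterator, ctx, net):
    acc = mx.metric.Accuracy()
    for data, label in data_iterator:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)   # one-hot, shape (batch, 4)
        out = net(data)
        p = nd.argmax(out, axis=1)         # class indices, shape (batch,)
        acc.update(preds=p, labels=label)  # indices vs. one-hot vectors: mismatch
    return acc.get()[1]
```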
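To see why class indices and one-hot vectors cannot be compared directly, here is a NumPy analogy (it does not reproduce what `mx.metric.Accuracy` does internally). The batch is hypothetical and every prediction is correct, yet the element-wise comparison broadcasts the `(4,)` index array across the `(4, 4)` label matrix and reports 0.25. It only runs because the batch size happens to equal the number of classes; with other batch sizes NumPy raises a shape error instead.

```python
import numpy as np

# Hypothetical batch of 4 samples whose true classes are 0, 2, 1, 3 (one-hot rows).
labels = np.eye(4, dtype=int)[[0, 2, 1, 3]]

# Assume the model gets every sample right, so argmax yields the same indices.
pred_indices = np.array([0, 2, 1, 3])

# (4,) vs (4, 4): broadcasting compares each label row to the whole index array.
naive = (pred_indices == labels).mean()
print(naive)  # 0.25

# Comparing against class indices gives the intended accuracy.
correct = (pred_indices == labels.argmax(axis=1)).mean()
print(correct)  # 1.0
```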
My current solution, which applies `argmax` to the labels as well, is a little hacky:

```python
import mxnet as mx
from mxnet import nd

def evaluate_accuracy(data_iterator, ctx, net):
    acc = mx.metric.Accuracy()
    for data, label in data_iterator:
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        out = net(data)
        p = nd.argmax(out, axis=1)    # predicted class index per sample
        l = nd.argmax(label, axis=1)  # one-hot label -> true class index
        acc.update(preds=p, labels=l)
    return acc.get()[1]
```
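Applying `argmax` to both outputs and labels reduces the problem to comparing class indices, which is what accuracy means here. A framework-free NumPy sketch of the same computation follows; the helper name `one_hot_accuracy` and the sample outputs are hypothetical, not part of MXNet.

```python
import numpy as np

def one_hot_accuracy(outputs, one_hot_labels):
    """Fraction of samples whose highest-scoring output matches the one-hot label."""
    pred_idx = np.argmax(outputs, axis=1)         # predicted class per sample
    true_idx = np.argmax(one_hot_labels, axis=1)  # true class per sample
    return float(np.mean(pred_idx == true_idx))

# Hypothetical model outputs (probabilities) for 4 samples over 4 classes.
outputs = np.array([
    [0.7, 0.1, 0.1, 0.1],  # predicts class 0
    [0.2, 0.5, 0.2, 0.1],  # predicts class 1
    [0.1, 0.2, 0.6, 0.1],  # predicts class 2
    [0.4, 0.3, 0.2, 0.1],  # predicts class 0
])
labels = np.eye(4)[[0, 1, 2, 3]]  # true classes 0, 1, 2, 3

print(one_hot_accuracy(outputs, labels))  # 0.75
```

Three of the four predictions match their labels, hence 0.75.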