I'm trying to create a masked version of `SparseCategoricalAccuracy` in TF 2.0 that can be passed to the Keras API via `compile(metrics=[get_masked_acc_metric_fn()])`.
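As background on the API mentioned above: `compile(metrics=...)` accepts metric objects as well as plain functions of `(y_true, y_pred)`, which Keras wraps in a mean metric that averages the returned values over batches. A plain function that uses only tensor ops, and creates no metric objects, could look like the minimal sketch below. It assumes integer labels with padded steps marked by `ignore_label`; the name `masked_sparse_accuracy` is hypothetical.

```python
import tensorflow as tf


def masked_sparse_accuracy(y_true, y_pred, ignore_label=-1):
    """Hypothetical stateless masked accuracy for one batch (no variables)."""
    # Drop a trailing size-1 label axis so y_true matches y_pred without the class axis.
    y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])
    y_true = tf.cast(y_true, tf.int64)
    preds = tf.argmax(y_pred, axis=-1)  # int64 class indices
    # True for steps that count, False for steps labeled ignore_label.
    mask = tf.not_equal(y_true, ignore_label)
    correct = tf.logical_and(tf.equal(y_true, preds), mask)
    n_correct = tf.reduce_sum(tf.cast(correct, tf.float32))
    n_valid = tf.reduce_sum(tf.cast(mask, tf.float32))
    # divide_no_nan returns 0 when a batch has no unmasked steps.
    return tf.math.divide_no_nan(n_correct, n_valid)
```

It would be passed as `metrics=[masked_sparse_accuracy]`. Because it returns one value per batch, the reported number is a mean of per-batch accuracies, which can differ from the accuracy over all unmasked steps when batches contain different numbers of padded steps.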
The function looks like:
```python
import tensorflow as tf


def get_masked_acc_metric_fn(ignore_label=-1):
    """Gets the masked accuracy function."""
    def masked_acc_fn(y_true, y_pred):
        """Masked accuracy."""
        y_true = tf.squeeze(y_true)
        # Mask is True for time steps to keep, False where y_true == ignore_label
        mask = tf.not_equal(y_true, ignore_label)
        # Note: a new SparseCategoricalAccuracy object (with its own variables)
        # is constructed every time masked_acc_fn runs.
        masked_acc = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_masked_accuracy', dtype=tf.float32)(y_true, y_pred, sample_weight=mask)
        return masked_acc
    return masked_acc_fn
```
This works in eager mode. However, when running in graph mode, I get the error:

```
ValueError: tf.function-decorated function tried to create variables on non-first call
```
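The error message is consistent with `masked_acc_fn` constructing a new `SparseCategoricalAccuracy`, and with it new variables, each time it runs, while `tf.function` only allows variables to be created on the first call. One way around this, sketched here under the assumption that labels are integers and padded steps carry `ignore_label`, is to subclass `SparseCategoricalAccuracy` so its variables are created once in the constructor and only `update_state` runs inside the traced function. The class name `MaskedSparseCategoricalAccuracy` is hypothetical.

```python
import tensorflow as tf


class MaskedSparseCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    """Hypothetical sparse categorical accuracy that skips `ignore_label` steps.

    The metric's variables are created once, in the parent constructor, so
    calling update_state from a tf.function does not create new ones.
    """

    def __init__(self, ignore_label=-1, name="masked_accuracy", **kwargs):
        super().__init__(name=name, **kwargs)
        self.ignore_label = ignore_label

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Drop a trailing size-1 label axis so y_true matches y_pred[..., 0].
        y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])
        # Weight 1.0 for real steps, 0.0 for steps labeled ignore_label.
        mask = tf.cast(tf.not_equal(y_true, self.ignore_label), tf.float32)
        if sample_weight is not None:
            sw = tf.cast(sample_weight, tf.float32)
            # Assumes weights of shape (batch,) or (batch, time); add trailing
            # axes so per-sample weights broadcast over the time dimension.
            while sw.shape.rank < mask.shape.rank:
                sw = tf.expand_dims(sw, -1)
            mask = mask * sw
        return super().update_state(y_true, y_pred, sample_weight=mask)

    def get_config(self):
        config = super().get_config()
        config.update({"ignore_label": self.ignore_label})
        return config
```

It would be passed as an object rather than through a factory call, e.g. `model.compile(..., metrics=[MaskedSparseCategoricalAccuracy(ignore_label=-1)])`. Unlike a per-batch function, the weighted counts accumulate across batches until the metric is reset, so the result is correct predictions divided by unmasked steps over the whole epoch.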