I'm trying to create a masked version of SparseCategoricalAccuracy in TensorFlow 2.0 that can be passed to the Keras API via compile(metrics=[get_masked_acc_metric_fn()]).

The function looks like:

def get_masked_acc_metric_fn(ignore_label=-1):
    """Gets the masked accuracy function."""
    def masked_acc_fn(y_true, y_pred):
        """Masked accuracy."""
        y_true = tf.squeeze(y_true)
        # Create mask for time steps we don't care about
        mask = tf.not_equal(y_true, ignore_label)
        masked_acc = tf.keras.metrics.SparseCategoricalAccuracy(
            'test_masked_accuracy', dtype=tf.float32)(y_true, y_pred, sample_weight=mask)
        return masked_acc

    return masked_acc_fn

This works in eager mode. However, when running in graph mode, I get the error:

ValueError: tf.function-decorated function tried to create variables on non-first call
Feynman27

1 Answer

This seems to work as a temporary workaround:

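import tensorflow as tf


class MaskedSparseCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):
    def __init__(self, name="masked_sparse_categorical_accuracy", dtype=None, ignore_label=-1):
        super(MaskedSparseCategoricalAccuracy, self).__init__(name, dtype=dtype)
        self.ignore_label = ignore_label

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Keep the standard (y_true, y_pred, sample_weight) signature so Keras can
        # pass sample_weight by keyword. Any incoming sample_weight is replaced by the mask.
        mask = tf.not_equal(y_true, self.ignore_label)
        return super(MaskedSparseCategoricalAccuracy, self).update_state(y_true, y_pred, sample_weight=mask)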
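To make it clearer what the mask does when it is passed as sample_weight, here is a minimal pure-Python sketch of the arithmetic I'd expect: each position gets weight 1 unless its label equals ignore_label, and the result is the weighted number of correct predictions divided by the total weight. The names masked_accuracy, labels, and scores are made up for illustration. This is not the Keras implementation, only a small reference you can compare the metric's output against.

```python
def masked_accuracy(y_true, y_pred, ignore_label=-1):
    """Accuracy over positions whose label differs from ignore_label.

    y_true: list of integer class labels.
    y_pred: list of per-class score lists, one per label.
    """
    correct = 0.0
    counted = 0.0
    for label, scores in zip(y_true, y_pred):
        # Weight 0 drops the position from both numerator and denominator.
        weight = 0.0 if label == ignore_label else 1.0
        predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax
        correct += weight * (predicted == label)
        counted += weight
    # Guard against a batch in which every position is masked.
    return correct / counted if counted else 0.0


labels = [2, -1, 0, 1]
scores = [
    [0.1, 0.2, 0.7],    # predicts class 2, correct
    [0.9, 0.05, 0.05],  # label is -1, ignored
    [0.3, 0.6, 0.1],    # predicts class 1, wrong
    [0.2, 0.5, 0.3],    # predicts class 1, correct
]
print(masked_accuracy(labels, scores))  # 2 correct out of 3 counted positions
```

Note that the ignored position is excluded from the denominator as well, so the result is 2/3 rather than 2/4.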
Feynman27