I have created a custom Keras metric, similar to the demo implementation below:
import tensorflow as tf
class MyMetric(tf.keras.metrics.Mean):
    """Demo custom metric layered on ``tf.keras.metrics.Mean``.

    Only ``y_pred`` (and the optional ``sample_weight``) is forwarded to the
    base class; ``y_true`` is accepted for signature compatibility but unused.
    """

    def __init__(self, name='my_metric', dtype=None):
        """Pass ``name`` and ``dtype`` straight through to the base class."""
        super().__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Delegate to the parent's ``update_state`` using ``y_pred`` only.

        Args:
            y_true: Ignored by this demo implementation.
            y_pred: Values handed to the parent's ``update_state``.
            sample_weight: Optional weights, forwarded unchanged.

        Returns:
            Whatever the parent's ``update_state`` returns.
        """
        return super().update_state(y_pred, sample_weight=sample_weight)
I have turned the implementation into a Python module with the init/main files and added its path to the system's PYTHONPATH.
I can use the metric when I train the Keras model.
Unfortunately, I haven't found a way to make the custom metric available to TensorFlow Model Analysis (TFMA).
In my interactive context notebook, I can load the metric when I create the eval_config:
import tensorflow as tf
import tensorflow_model_analysis as tfma
from mymetric.metric import MyMetric
# Wrap the custom metric instance and let TFMA derive the metric specs from it.
metrics = [MyMetric()]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

# NOTE(review): the spec printed when the evaluator runs contains only
# `class_name` and `config` for MyMetric, with no module path. Presumably TFMA
# needs the metric's module recorded in the MetricConfig to import a custom
# class at evaluation time -- confirm against the TFMA metrics docs.
model_spec = tfma.ModelSpec(label_key='label_xf')
default_slicing = tfma.SlicingSpec()
eval_config = tfma.EvalConfig(
    model_specs=[model_spec],
    metrics_specs=metrics_specs,
    slicing_specs=[default_slicing],
)

# Wire the evaluator to the upstream pipeline components' outputs.
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config,
)
When I try to execute the evaluator, the metric is listed in the metric specifications:
metrics_specs {
metrics {
class_name: "MyMetric"
config: "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
threshold {
}
}
}
but the execution fails with the error
ValueError: Unknown metric function: MyMetric
Since the metric calculation is executed via Apache Beam's executor.Do function, I assume that Beam can't find the module (even though it is on the PYTHONPATH). If that is the case, how can I make the module available to Apache Beam beyond the PYTHONPATH configuration?
Traceback:
/usr/local/lib/python3.6/dist-packages/tensorflow_model_analysis/metrics/metric_specs.py in _deserialize_tf_metric(metric_config, custom_objects)
741 cls_name, cfg = _tf_class_and_config(metric_config)
742 with tf.keras.utils.custom_object_scope(custom_objects):
--> 743 return tf.keras.metrics.deserialize({'class_name': cls_name, 'config': cfg})
744
745
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py in deserialize(config, custom_objects)
3441 module_objects=globals(),
3442 custom_objects=custom_objects,
-> 3443 printable_module_name='metric function')
3444
3445
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
345 config = identifier
346 (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 347 config, module_objects, custom_objects, printable_module_name)
348
349 if hasattr(cls, 'from_config'):
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
294 cls = get_registered_object(class_name, custom_objects, module_objects)
295 if cls is None:
--> 296 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
297
298 cls_config = config['config']
ValueError: Unknown metric function: MyMetric