Tags: tensorflow, apache-beam, metrics, tfx, tensorflow-model-analysis

How to make a custom metric available to TFMA/Beam?


I have created a custom Keras metric, similar to the demo implementation below:

import tensorflow as tf

class MyMetric(tf.keras.metrics.Mean):
    """Demo metric: the (weighted) mean of the model's predictions."""

    def __init__(self, name='my_metric', dtype=None):
        super(MyMetric, self).__init__(name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true is ignored; only the predictions are averaged.
        return super(MyMetric, self).update_state(
            y_pred, sample_weight=sample_weight)
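To make the metric's semantics concrete without TensorFlow, here is a plain-Python analogue of what MyMetric accumulates: a weighted mean of y_pred, with y_true ignored. WeightedMeanOfPredictions is a hypothetical illustration, not part of Keras; it only mirrors the total/count bookkeeping of a weighted mean, assuming one weight per prediction.

```python
class WeightedMeanOfPredictions:
    """Hypothetical plain-Python analogue of MyMetric (not a Keras class)."""

    def __init__(self):
        self.total = 0.0  # running sum of weight * prediction
        self.count = 0.0  # running sum of weights

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true is ignored, exactly as in MyMetric.update_state.
        if sample_weight is None:
            sample_weight = [1.0] * len(y_pred)
        for pred, weight in zip(y_pred, sample_weight):
            self.total += pred * weight
            self.count += weight

    def result(self):
        # Weighted mean of all predictions seen so far (0.0 before any update).
        return self.total / self.count if self.count else 0.0


m = WeightedMeanOfPredictions()
m.update_state([0, 1], [0.2, 0.8])
m.update_state([1], [0.5], sample_weight=[2.0])
print(m.result())  # 0.5
```

The labels passed as y_true have no effect on the result; only predictions and weights do.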

I have packaged the implementation as a Python module (with init/main files) and added its location to the system's PYTHONPATH. I can use the metric when I train the Keras model.
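Adding a directory to PYTHONPATH only affects processes that inherit that environment variable. A quick stdlib-only check, run in the same process (or worker) that executes the evaluation, shows whether the package actually resolves there. The temporary directory and the is_importable helper below are hypothetical; inserting into sys.path stands in for setting PYTHONPATH.

```python
import importlib
import importlib.util
import os
import sys
import tempfile


def is_importable(module_name):
    """Return True if `module_name` can be located on the current sys.path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # find_spec imports parent packages, which raises if they are missing.
        return False


# Hypothetical layout matching `from mymetric.metric import MyMetric`.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "mymetric"))
open(os.path.join(root, "mymetric", "__init__.py"), "w").close()
with open(os.path.join(root, "mymetric", "metric.py"), "w") as f:
    f.write("class MyMetric:\n    pass\n")

before = is_importable("mymetric.metric")
sys.path.insert(0, root)  # what a PYTHONPATH entry does at interpreter startup
importlib.invalidate_caches()
after = is_importable("mymetric.metric")
print(before, after)  # False True
```

If the same check returns False inside the evaluation environment, the Unknown metric error is expected regardless of how the notebook itself is configured.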

Unfortunately, I haven't found a way to make the custom metric available to TensorFlow Model Analysis (TFMA).

In my TFX InteractiveContext notebook, I can import the metric and use it when I create the eval_config:

import tensorflow as tf
import tensorflow_model_analysis as tfma
from tfx.components import Evaluator
from mymetric.metric import MyMetric

# Build metrics_specs from the Keras metric instance.
metrics = [MyMetric()]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)

eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label_xf')],
    metrics_specs=metrics_specs,
    slicing_specs=[tfma.SlicingSpec()]
)

# example_gen, trainer and model_resolver are defined in earlier notebook cells.
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config)

When I execute the evaluator, the metric appears in the metrics specifications

metrics_specs {
  metrics {
    class_name: "MyMetric"
    config: "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
    threshold {
    }
  }
}

but the execution fails with the error

ValueError: Unknown metric function: MyMetric

Since the metrics are computed in an Apache Beam pipeline launched from the Evaluator executor's Do function, I assume that Beam can't find the module (even though it is on the PYTHONPATH). If that is the case, how can I make the module available to Apache Beam beyond setting PYTHONPATH?

Traceback:

/usr/local/lib/python3.6/dist-packages/tensorflow_model_analysis/metrics/metric_specs.py in _deserialize_tf_metric(metric_config, custom_objects)
    741   cls_name, cfg = _tf_class_and_config(metric_config)
    742   with tf.keras.utils.custom_object_scope(custom_objects):
--> 743     return tf.keras.metrics.deserialize({'class_name': cls_name, 'config': cfg})
    744 
    745 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/metrics.py in deserialize(config, custom_objects)
   3441       module_objects=globals(),
   3442       custom_objects=custom_objects,
-> 3443       printable_module_name='metric function')
   3444 
   3445 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
    345     config = identifier
    346     (cls, cls_config) = class_and_config_for_serialized_keras_object(
--> 347         config, module_objects, custom_objects, printable_module_name)
    348 
    349     if hasattr(cls, 'from_config'):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    294   cls = get_registered_object(class_name, custom_objects, module_objects)
    295   if cls is None:
--> 296     raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    297 
    298   cls_config = config['config']

ValueError: Unknown metric function: MyMetric
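The metrics spec carries only a class name plus a JSON-encoded constructor config, never the Python code. The traceback shows _deserialize_tf_metric passing a class name and config to tf.keras.metrics.deserialize, which raises when the name is found neither in custom_objects nor in Keras' own module objects. The plain-Python sketch below models that lookup in simplified form (the real get_registered_object also consults other registries, and real Keras prefers from_config); lookup_sketch, deserialize_sketch, and the placeholder MyMetric are hypothetical names, not TensorFlow APIs.

```python
import json

# `config` string copied from the metrics_specs output in the question.
config_str = "{\"dtype\": \"float32\", \"name\": \"my_metric\"}"
serialized = {"class_name": "MyMetric", "config": json.loads(config_str)}


def lookup_sketch(name, custom_objects=None, module_objects=None):
    """Simplified stand-in for get_registered_object (not the Keras source)."""
    for registry in (custom_objects or {}, module_objects or {}):
        if name in registry:
            return registry[name]
    return None


def deserialize_sketch(config, custom_objects=None, module_objects=None):
    class_name = config["class_name"]
    cls = lookup_sketch(class_name, custom_objects, module_objects)
    if cls is None:
        # Same message format as in the traceback.
        raise ValueError("Unknown metric function: " + class_name)
    return cls(**config["config"])


class MyMetric:
    """Plain-Python placeholder for the real Keras metric class."""

    def __init__(self, name="my_metric", dtype=None):
        self.name, self.dtype = name, dtype


# Only built-in metrics are known to the lookup: MyMetric is missing.
try:
    deserialize_sketch(serialized, module_objects={})
    error_message = None
except ValueError as exc:
    error_message = str(exc)
print(error_message)  # Unknown metric function: MyMetric

# Once the class is supplied (e.g. via custom objects), the lookup succeeds.
metric = deserialize_sketch(serialized, custom_objects={"MyMetric": MyMetric})
print(metric.name, metric.dtype)  # my_metric float32
```

In other words, the process doing the deserialization must be able to obtain the class object itself; the serialized spec alone is not enough.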

Solution

  • You need to specify the module so that TFX knows where to find your MyMetric class. One way of doing this is to specify it as part of the metric specs:

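    from tensorflow_model_analysis import config

    # class_name names the metric class; module is the importable Python
    # module that defines it (here: mymodule/mymetric.py).
    metric_config = [config.MetricConfig(class_name='MyMetric', module='mymodule.mymetric')]

    metrics_specs = [config.MetricsSpec(metrics=metric_config)]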

    You will also need to create a package called mymodule and put your MyMetric class in mymodule/mymetric.py for this to work. Also make sure that the package is importable from wherever the code runs (which should be the case if you have added its parent directory to your PYTHONPATH).
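With module set, the evaluation has an import path for the class instead of relying only on names Keras already knows. Conceptually, resolving it is an import followed by an attribute lookup. Below is a minimal stdlib sketch assuming the mymodule/mymetric.py layout described above; resolve_metric_class is a hypothetical helper, not TFMA's implementation, and the placeholder class stands in for the real Keras metric.

```python
import importlib
import os
import sys
import tempfile


def resolve_metric_class(module, class_name):
    """Hypothetical helper: import `module`, then fetch `class_name` from it."""
    return getattr(importlib.import_module(module), class_name)


# Recreate the layout from the answer: mymodule/__init__.py and mymodule/mymetric.py
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "mymodule"))
open(os.path.join(root, "mymodule", "__init__.py"), "w").close()
with open(os.path.join(root, "mymodule", "mymetric.py"), "w") as f:
    f.write("class MyMetric:\n    pass\n")

sys.path.insert(0, root)  # stands in for having `root` on PYTHONPATH
importlib.invalidate_caches()

cls = resolve_metric_class("mymodule.mymetric", "MyMetric")
custom_objects = {cls.__name__: cls}  # name -> class mapping for a lookup
print(cls.__module__, cls.__name__)  # mymodule.mymetric MyMetric
```

If the import step fails in the environment that runs the evaluator, the fix is on the packaging side (making mymodule importable there), not in the metrics spec.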