
How to make Keras compute a certain metric on validation data only?


I'm using tf.keras with TensorFlow 1.14.0. I have implemented a custom metric that is quite computationally intensive, and it slows down training if I simply add it to the list of metrics passed to model.compile(..., metrics=[...]).

How do I make Keras skip computation of the metric during training iterations but compute it on validation data (and print it) at the end of each epoch?


Solution

  • One way to do this is to give the metric a boolean tf.Variable that decides whether the expensive computation runs, and to switch that flag on and off from a callback when a test (validation) pass begins and ends. For example:

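    import tensorflow as tf


    def expensive_computation(y_true, y_pred):
        # Placeholder for the costly computation (here: total absolute error)
        y_true = tf.cast(y_true, y_pred.dtype)
        return tf.cast(tf.reduce_sum(tf.abs(y_true - y_pred)), tf.float32)


    class MyCustomMetric(tf.keras.metrics.Metric):

        def __init__(self, **kwargs):
            # Initialise as normal and add a flag variable for when to run the computation
            super(MyCustomMetric, self).__init__(**kwargs)
            self.metric_variable = self.add_weight(
                name='metric_variable', initializer='zeros', dtype=tf.float32)
            self.on = tf.Variable(False, trainable=False, name='on')

        def update_state(self, y_true, y_pred, sample_weight=None):
            # A Python `if` on a variable fails in graph mode; tf.cond reads the
            # flag every time the op runs, so toggling it takes effect immediately
            value = tf.cond(tf.convert_to_tensor(self.on),
                            lambda: expensive_computation(y_true, y_pred),
                            lambda: tf.zeros([], dtype=tf.float32))
            return self.metric_variable.assign_add(value)

        def result(self):
            # Stays 0 during training because nothing is accumulated while the flag is off
            return tf.identity(self.metric_variable)

        def reset_state(self):
            # Reset only the accumulator, not the flag; set_value also runs the
            # assignment in TF 1.x graph mode, where a bare .assign() only builds an op
            tf.keras.backend.set_value(self.metric_variable, 0.)

        def reset_states(self):
            # TF 1.14 calls reset_states(); newer versions call reset_state()
            self.reset_state()


    class ToggleMetrics(tf.keras.callbacks.Callback):
        '''On test begin (i.e. when evaluate() is called or validation data
        is run during fit()) switch the flag on; switch it off on test end.'''

        def _set_flag(self, value):
            for metric in self.model.metrics:
                # Match by class: the default metric name is 'my_custom_metric'
                if isinstance(metric, MyCustomMetric):
                    tf.keras.backend.set_value(metric.on, value)

        def on_test_begin(self, logs=None):
            self._set_flag(True)

        def on_test_end(self, logs=None):
            self._set_flag(False)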
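A different approach, not part of the answer above, avoids the flag entirely: leave the expensive metric out of compile() and compute it yourself in a callback at the end of each epoch. The sketch below assumes the validation data fit in memory as NumPy arrays and that the metric can be computed in NumPy on the model's predictions; expensive_metric, ValidationOnlyMetric and the val_expensive_metric key are hypothetical names chosen for illustration.

```python
import numpy as np
import tensorflow as tf


def expensive_metric(y_true, y_pred):
    # Hypothetical stand-in for the costly metric: mean absolute error in NumPy
    return float(np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred))))


class ValidationOnlyMetric(tf.keras.callbacks.Callback):
    """Compute `metric_fn` on held-out data once per epoch (never per batch)."""

    def __init__(self, x_val, y_val, metric_fn,
                 metric_name='val_expensive_metric', batch_size=32):
        super(ValidationOnlyMetric, self).__init__()
        self.x_val = x_val
        self.y_val = y_val
        self.metric_fn = metric_fn
        self.metric_name = metric_name
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        # One extra forward pass over the validation set per epoch
        y_pred = self.model.predict(self.x_val, batch_size=self.batch_size, verbose=0)
        value = self.metric_fn(self.y_val, y_pred)
        if logs is not None:
            # Callbacks that run after this one in the same epoch (e.g. History) see this key
            logs[self.metric_name] = value
        print('{}: {:.4f}'.format(self.metric_name, value))
```

Usage would look like model.fit(x_train, y_train, epochs=10, callbacks=[ValidationOnlyMetric(x_val, y_val, expensive_metric)]), where the arrays are your own. The trade-off is an extra prediction pass over the validation set each epoch, on top of the regular validation pass if you also give validation_data to fit(), but the training step never touches the metric.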