I'm using tf.keras with TensorFlow 1.14.0. I have implemented a custom metric that is quite computationally intensive, and it slows down training if I simply add it to the list of metrics passed to model.compile(..., metrics=[...]).
How do I make Keras skip computation of the metric during training iterations but compute it on validation data (and print it) at the end of each epoch?
To do this you can add a boolean tf.Variable flag to the metric that gates whether the computation runs, and toggle that flag from a callback whenever validation or evaluation starts. e.g.
import tensorflow as tf

class MyCustomMetric(tf.keras.metrics.Metric):
    def __init__(self, **kwargs):
        # Initialise as normal and add a flag variable for when to run computation
        super(MyCustomMetric, self).__init__(**kwargs)
        self.metric_variable = self.add_weight(name='metric_variable', initializer='zeros')
        self.on = tf.Variable(False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Use tf.cond so the check also works when Keras builds this into a graph
        self.metric_variable.assign_add(
            tf.cond(self.on,
                    lambda: computation_result,  # your expensive computation goes here
                    lambda: 0.))

    def result(self):
        return self.metric_variable

    def reset_states(self):
        self.metric_variable.assign(0.)
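Note the conditional is a tf.cond rather than a plain Python if: under TF 1.14's default graph mode, Keras compiles update_state into a graph, where evaluating a tf.Variable as a Python bool raises an error. (computation_result above is a placeholder for whatever your expensive metric actually computes.)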
class ToggleMetrics(tf.keras.callbacks.Callback):
    '''On test begin (i.e. when evaluate() is called or
    validation data is run during fit()) toggle the metric flag.'''
    def on_test_begin(self, logs=None):
        for metric in self.model.metrics:
            # Match by type: the default metric name is the snake_case class
            # name ('my_custom_metric'), so checking for 'MyCustomMetric' in
            # metric.name would silently fail.
            if isinstance(metric, MyCustomMetric):
                # set_value executes the assignment immediately in both graph and eager mode
                tf.keras.backend.set_value(metric.on, True)

    def on_test_end(self, logs=None):
        for metric in self.model.metrics:
            if isinstance(metric, MyCustomMetric):
                tf.keras.backend.set_value(metric.on, False)
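For completeness, here is a minimal sketch of how the pieces fit together; the model, optimizer, loss, and the data names (x_train, y_train, x_val, y_val) are illustrative placeholders, and it assumes computation_result has been filled in:

# A tiny stand-in model, purely for illustration
model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=[MyCustomMetric()])

# The expensive metric only runs while validation/evaluation is active,
# because ToggleMetrics flips the flag in on_test_begin/on_test_end
model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=10,
          callbacks=[ToggleMetrics()])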