TensorFlow Callback as Custom Metric for CTC

To get more metrics during the training of my model (written in TensorFlow 2.1.0), such as the Character Error Rate (CER) and Word Error Rate (WER), I created a callback that I pass to the model's fit function. It computes the CER and WER at the end of each epoch.

This is my second choice: I wanted to create a custom metric instead, but custom metrics can only use Keras backend (tensor) operations. Does anyone have advice on how to convert the callback below into a custom metric (which could then be computed during training on the validation and/or training data)?

Some roadblocks I encountered are:

  • Failure to convert the K.ctc_decode result to a sparse tensor
  • How can you calculate a distance like edit-distance using the Keras backend?
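For the second roadblock, it helps to recall what the edit distance actually is: the Levenshtein distance, the minimum number of insertions, deletions and substitutions that turn one sequence into the other. The callback gets it from the editdistance package; below is a minimal plain-Python dynamic-programming sketch of the same quantity, plus the per-sample CER exactly as the callback normalizes it (by the longer of the two sequences). The names `levenshtein` and `cer` are illustrative, not library functions.

```python
def levenshtein(a, b):
    """Minimum number of insertions, deletions and substitutions that turn a into b."""
    # prev[j] is the distance between the first i-1 items of a and the first j items of b
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]  # turning a[:i] into an empty sequence takes i deletions
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,             # delete x
                curr[j - 1] + 1,         # insert y
                prev[j - 1] + (x != y),  # substitute x by y (free when equal)
            ))
        prev = curr
    return prev[-1]


def cer(pred, label):
    """Per-sample CER as the callback computes it: distance over the longer sequence.

    The max(..., 1) guard against two empty sequences is an addition of this sketch.
    """
    return levenshtein(pred, label) / max(len(pred), len(label), 1)


print(levenshtein("kitten", "sitting"))  # 3
print(cer([5, 2, 7], [5, 7]))
```

Inside a Keras metric this loop cannot run on symbolic tensors, which is why the answer below switches to tf.edit_distance on sparse tensors.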
import editdistance
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


class Metrics(tf.keras.callbacks.Callback):
    def __init__(self, valid_data, steps):
        """
        valid_data is a TFRecordDataset with batches of 100 elements per batch, shuffled and repeated infinitely.
        steps defines the number of batches per epoch.
        """
        super(Metrics, self).__init__()
        self.valid_data = valid_data
        self.steps = steps

    def on_train_begin(self, logs=None):
        # History of the per-epoch mean CER and WER
        self.cer = []
        self.wer = []

    def on_epoch_end(self, epoch, logs=None):
        imgs = []
        labels = []
        # The dataset repeats infinitely, so stop after `steps` batches
        for idx, (img, label) in enumerate(self.valid_data.as_numpy_iterator()):
            if idx >= self.steps:
                break
            imgs.append(img)
            labels.extend(label)

        # Softmax outputs of shape (samples, time_steps, num_classes)
        out = self.model.predict(np.concatenate(imgs, axis=0))
        # Every sample is decoded over all time steps
        out_len = np.full(out.shape[0], out.shape[1])

        # Greedy CTC decoding; the dense result is right-padded with -1
        decode, log = K.ctc_decode(out, out_len, greedy=True)
        decode = [[int(p) for p in x if p != -1] for x in decode[0].numpy()]

        cer, wer = [], []
        for (pred, lab) in zip(decode, labels):
            dist = editdistance.eval(pred, lab)
            cer.append(dist / max(len(pred), len(lab), 1))
            # A sample counts as an error unless the whole sequence matches
            wer.append(not np.array_equal(pred, lab))

        # Report this epoch only, instead of the mean over all epochs so far
        self.cer.append(np.mean(cer))
        self.wer.append(np.mean(wer))
        print("Mean CER: {}".format(self.cer[-1]))
        print("Mean WER: {}".format(self.wer[-1]))
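The decoding step in on_epoch_end relies on K.ctc_decode(..., greedy=True), which performs best-path decoding: take the most likely class at each time step, merge consecutive repeats, and drop the blank. Keras then returns the sequences as a dense tensor right-padded with -1, which is why the callback filters out -1. A minimal pure-Python sketch of the best-path step for one sequence, assuming the blank is the last class index (the default of TensorFlow's greedy CTC decoder); `greedy_ctc_decode` and `one_hot_scores` are illustrative names:

```python
def greedy_ctc_decode(probs, blank):
    """Best-path CTC decoding of one sequence of per-time-step class scores."""
    # Most likely class at each time step
    best = [max(range(len(step)), key=step.__getitem__) for step in probs]
    decoded, previous = [], None
    for label in best:
        # Emit a label only if it differs from the previous time step and is not the blank
        if label != previous and label != blank:
            decoded.append(label)
        previous = label
    return decoded


def one_hot_scores(label, num_classes=3):
    """Toy scores that make `label` the most likely class."""
    return [0.9 if k == label else 0.05 for k in range(num_classes)]


# Class 2 is the blank (the last of 3 classes); it separates the repeated 0s
probs = [one_hot_scores(k) for k in [0, 0, 2, 0, 1, 1, 2]]
print(greedy_ctc_decode(probs, blank=2))  # [0, 0, 1]
```

Note that a blank between two identical labels is what keeps them as two separate characters; without it they would be merged into one.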


  • Solved in TF 2.3.1, but this should apply to earlier 2.x versions as well.

    Some remarks:

    • Information on how to properly implement a TensorFlow custom metric is scarce. The question implied using a callback to implement the metric. I believe this makes epochs longer, because the metric is computed as an explicit extra step in on_epoch_end. Implementing it as a subclass of tensorflow.keras.metrics.Metric seems the right way, and it reports results while the epoch is still running (if verbose is set accordingly).
    • Calculating the edit distance for the CER is straightforward with tf.edit_distance (which operates on sparse tensors); the resulting distances can then be used to calculate the WER with a little extra tf logic.
    • Alas, I have yet to find out how to implement both the CER and WER in one metric (the two share a lot of duplicate code). If anyone knows how to do so, please contact me.
    • Custom metrics can simply be added into the compilation of your TF model: self.model.compile(optimizer=opt, loss=loss, metrics=[CERMetric(), WERMetric()])
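On the open question of combining CER and WER in one metric: both come from the same per-sample distance, so one object can compute each distance once and update two accumulators. Below is a framework-free sketch of that structure, mirroring the accumulator/counter pattern of the metrics below with plain Python numbers; `ErrorRates` is a hypothetical name, and porting it to a single tf.keras.metrics.Metric subclass is not shown. As in WERMetric, a sample counts as wrong whenever any edit is needed, which equals the word error rate only when each sample is a single word.

```python
def levenshtein(a, b):
    """Edit distance (insertions, deletions, substitutions) between two sequences."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]


class ErrorRates:
    """Accumulates CER and WER together, computing each edit distance only once."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        self.cer_sum = 0.0  # sum of normalized distances
        self.wer_sum = 0.0  # number of samples that are not an exact match
        self.count = 0

    def update_state(self, y_true, y_pred):
        for truth, hyp in zip(y_true, y_pred):
            # Normalized by the label length, like tf.edit_distance(normalize=True);
            # the max(..., 1) guard against empty labels is an addition of this sketch
            dist = levenshtein(hyp, truth) / max(len(truth), 1)
            self.cer_sum += dist
            self.wer_sum += float(dist != 0)
            self.count += 1

    def result(self):
        # Like tf.math.divide_no_nan: report 0 before any update
        if self.count == 0:
            return {"CER": 0.0, "WER": 0.0}
        return {"CER": self.cer_sum / self.count, "WER": self.wer_sum / self.count}


metric = ErrorRates()
metric.update_state(y_true=[[1, 2, 3], [4, 5]], y_pred=[[1, 2, 3], [4]])
print(metric.result())  # {'CER': 0.25, 'WER': 0.5}
```

Note that this normalizes by the label length, whereas the question's callback divides by the longer of prediction and label, so the two CER values can differ.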
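    import tensorflow as tf
    from tensorflow.keras import backend as K


    class CERMetric(tf.keras.metrics.Metric):
        """A custom Keras metric to compute the Character Error Rate"""
        def __init__(self, name='CER_metric', **kwargs):
            super(CERMetric, self).__init__(name=name, **kwargs)
            self.cer_accumulator = self.add_weight(name="total_cer", initializer="zeros")
            self.counter = self.add_weight(name="cer_count", initializer="zeros")

        def update_state(self, y_true, y_pred, sample_weight=None):
            # Every sample is decoded over all time steps of y_pred
            input_shape = K.shape(y_pred)
            input_length = tf.ones(shape=input_shape[0]) * K.cast(input_shape[1], 'float32')
            # Greedy CTC decoding; the dense result is right-padded with -1
            decode, log = K.ctc_decode(y_pred, input_length, greedy=True)
            # Sparse tensors for tf.edit_distance; labels cast to int64 to match the decoder output
            decode = K.ctc_label_dense_to_sparse(decode[0], K.cast(input_length, 'int32'))
            y_true_sparse = K.ctc_label_dense_to_sparse(K.cast(y_true, 'int64'), K.cast(input_length, 'int32'))
            # Drop the -1 padding, then compute the edit distance normalized by the label length
            decode = tf.sparse.retain(decode, tf.not_equal(decode.values, -1))
            distance = tf.edit_distance(decode, y_true_sparse, normalize=True)
            self.cer_accumulator.assign_add(tf.reduce_sum(distance))
            self.counter.assign_add(K.cast(K.shape(y_true)[0], 'float32'))

        def result(self):
            return tf.math.divide_no_nan(self.cer_accumulator, self.counter)

        def reset_states(self):  # renamed reset_state in newer TF releases
            self.cer_accumulator.assign(0.0)
            self.counter.assign(0.0)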

    class WERMetric(tf.keras.metrics.Metric):
        """A custom Keras metric to compute the Word Error Rate"""
        def __init__(self, name='WER_metric', **kwargs):
            super(WERMetric, self).__init__(name=name, **kwargs)
            self.wer_accumulator = self.add_weight(name="total_wer", initializer="zeros")
            self.counter = self.add_weight(name="wer_count", initializer="zeros")

        def update_state(self, y_true, y_pred, sample_weight=None):
            input_shape = K.shape(y_pred)
            input_length = tf.ones(shape=input_shape[0]) * K.cast(input_shape[1], 'float32')
            decode, log = K.ctc_decode(y_pred, input_length, greedy=True)
            decode = K.ctc_label_dense_to_sparse(decode[0], K.cast(input_length, 'int32'))
            y_true_sparse = K.ctc_label_dense_to_sparse(K.cast(y_true, 'int64'), K.cast(input_length, 'int32'))
            decode = tf.sparse.retain(decode, tf.not_equal(decode.values, -1))
            distance = tf.edit_distance(decode, y_true_sparse, normalize=True)
            # A sample counts as incorrect whenever any edit is needed
            incorrect_words_amount = tf.reduce_sum(tf.cast(tf.not_equal(distance, 0), tf.float32))
            self.wer_accumulator.assign_add(incorrect_words_amount)
            self.counter.assign_add(K.cast(K.shape(y_true)[0], 'float32'))

        def result(self):
            return tf.math.divide_no_nan(self.wer_accumulator, self.counter)

        def reset_states(self):
            self.wer_accumulator.assign(0.0)
            self.counter.assign(0.0)
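The question's first roadblock, turning the K.ctc_decode output into a sparse tensor, is handled in update_state by K.ctc_label_dense_to_sparse followed by tf.sparse.retain(..., values != -1). For the decoded labels, the lengths passed are at least the row width, so the first call keeps every entry and the second drops the -1 padding while leaving the dense shape unchanged. A plain-Python picture of the resulting (indices, values, dense_shape) triple, with hypothetical helper names `dense_to_sparse` and `sparse_rows`; tf.edit_distance then compares hypothesis and truth row by row:

```python
def dense_to_sparse(dense, pad=-1):
    """Return (indices, values, dense_shape) for every entry of `dense` that is not `pad`."""
    indices, values = [], []
    for r, row in enumerate(dense):
        for c, v in enumerate(row):
            if v != pad:
                indices.append((r, c))
                values.append(v)
    # The shape is kept as-is: dropping padding does not shrink the tensor
    dense_shape = (len(dense), len(dense[0]) if dense else 0)
    return indices, values, dense_shape


def sparse_rows(indices, values, dense_shape):
    """Regroup a sparse triple into one sequence per row (indices are in row-major order)."""
    rows = [[] for _ in range(dense_shape[0])]
    for (r, _), v in zip(indices, values):
        rows[r].append(v)
    return rows


decoded = [[3, 1, -1], [2, -1, -1]]  # dense decoder output padded with -1
sparse = dense_to_sparse(decoded)
print(sparse)                # ([(0, 0), (0, 1), (1, 0)], [3, 1, 2], (2, 3))
print(sparse_rows(*sparse))  # [[3, 1], [2]]
```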