
How to monitor a filtered version of a metric in the EarlyStopping callback in TensorFlow?


I have always had this problem: when training neural networks, the validation loss can be noisy (sometimes even the training loss, if you are using stochastic layers such as dropout). This is especially true when the dataset is small.

As a result, callbacks such as EarlyStopping or ReduceLROnPlateau are triggered too early (even with a large patience). Also, sometimes I don't want to use a large patience in the ReduceLROnPlateau callback.

One solution would be, instead of monitoring a metric directly (e.g. val_loss), to monitor a version of it that is filtered across epochs (e.g. an exponential moving average of val_loss). However, I do not see an easy way to do this, because the callbacks only accept metrics that do not depend on previous epochs. I have tried using a custom training loop to reproduce the functionality of these callbacks with my custom filtered metric, but I don't think that is the right way. Is there a simpler way to monitor a filtered version of the loss in the callbacks, without reimplementing their whole functionality?

Edit:

This is what I mean by monitoring a filtered version of a metric. The current EarlyStopping works something like this:

best_loss = float('inf')
best_epoch = 0
for epoch in range(n_epochs):
    # ...
    new_loss = compute_epoch_loss()  # hypothetical helper: loss of the current epoch
    if new_loss < best_loss:
        best_loss = new_loss
        best_epoch = epoch
    if epoch - best_epoch > patience:
        break

Monitoring the filtered metric would be like this:

best_loss = float('inf')
filtered_loss = 10 # example initial value
best_epoch = 0
for epoch in range(n_epochs):
    # ...
    new_loss = compute_epoch_loss()  # hypothetical helper: loss of the current epoch
    filtered_loss = 0.1*new_loss + 0.9*filtered_loss
    if filtered_loss < best_loss:
        best_loss = filtered_loss
        best_epoch = epoch
    if epoch - best_epoch > patience:
        break
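
To make the difference concrete, the two loops above can be wrapped in one small pure-Python function and run on a made-up noisy loss curve (the numbers are illustrative, not from a real training run). stopping_epoch and its parameters are hypothetical names used only for this sketch; it keeps the same "epoch - best_epoch > patience" rule as the loops above, which may be off by one epoch compared with how Keras itself counts patience.

```python
def stopping_epoch(losses, patience, alpha=None, initial=None):
    """Return the epoch at which the early-stopping loop breaks, or len(losses).

    alpha=None monitors the raw loss; otherwise the exponential moving average
    filtered = alpha * new + (1 - alpha) * filtered is monitored.
    initial=None seeds the filter with the first loss instead of a fixed guess.
    """
    best_loss = float('inf')
    best_epoch = 0
    filtered = initial
    for epoch, new_loss in enumerate(losses):
        if alpha is None:
            monitored = new_loss
        elif filtered is None:
            filtered = new_loss  # seed the filter with the first observation
            monitored = filtered
        else:
            filtered = alpha * new_loss + (1 - alpha) * filtered
            monitored = filtered
        if monitored < best_loss:
            best_loss = monitored
            best_epoch = epoch
        if epoch - best_epoch > patience:
            return epoch
    return len(losses)


# Made-up noisy curve: a lucky dip at epoch 1, then slow, steady improvement.
noisy_losses = [1.0, 0.5, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5, 0.45, 0.4]
print(stopping_epoch(noisy_losses, patience=3))             # 5
print(stopping_epoch(noisy_losses, patience=3, alpha=0.1))  # 13
```

With patience=3, the raw loop stops at epoch 5 because of the lucky dip at epoch 1, while the EMA-filtered loop keeps improving and runs all 13 epochs.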

Solution

  • Thanks to the contributions of @Pedro Marques and @lescurel, I managed to write a metric class that smooths any other metric across epochs. It also works when a validation set is used during training, because it keeps separate state variables for training and validation.

    There is a base class that manages the state logic:

    import tensorflow as tf
    
    class SmoothMetric(tf.keras.metrics.Metric):
        """Wraps `metric` and smooths its per-epoch result across epochs.
    
        Keeps separate filter states for training and validation (test) mode.
        """
        def __init__(self, metric, name=None, **kwargs):
            if name is None:
                name = 'smooth_' + metric.name
            super().__init__(name=name, **kwargs)
            self.metric = metric
            self.in_test_step = tf.Variable(False)       # current mode, set by the callback
            self.prev_in_test_step = tf.Variable(False)  # mode of the last update_state call
            self.train_states = {}
            self.test_states = {}
    
        def add_state_variable(self, name, **kwargs):
            # One copy of each filter variable per mode.
            self.train_states[name] = tf.Variable(**kwargs)
            self.test_states[name] = tf.Variable(**kwargs)
    
        def get_state_variable(self, name):
            # Select the variable that belongs to the current mode.
            return tf.cond(self.in_test_step, lambda: self.test_states[name], lambda: self.train_states[name])
    
        def update_state(self, *args, **kwargs):
            self.metric.update_state(*args, **kwargs)
            self.prev_in_test_step.assign(self.in_test_step)
    
        def update_filter_state_previous_epoch(self):
            raise NotImplementedError('Needs to be overridden')
    
        def reset_state(self, *args, **kwargs):
            # Fold the finished epoch into the filter state of the mode that
            # produced it (prev_in_test_step), then restore the current mode.
            tmp = self.in_test_step.read_value()
            self.in_test_step.assign(self.prev_in_test_step)
            self.update_filter_state_previous_epoch()
            self.in_test_step.assign(tmp)
            self.metric.reset_state(*args, **kwargs)
    

    Then I define a subclass that implements the filter itself. Here it is an exponential moving average (EMA), but other filters, such as a simple moving average (SMA), can be implemented the same way.

    class SmoothMetricEMA(SmoothMetric):
        def __init__(self, metric, alpha=0.1, name=None, **kwargs) -> None:
            super().__init__(metric=metric, name=name, **kwargs)
            self.alpha = alpha
            self.add_state_variable('accum', initial_value=0.0)
            self.add_state_variable('idx', initial_value=0)
    
        def update_filter_state_previous_epoch(self):
            self.get_state_variable('accum').assign(self.result())
            self.get_state_variable('idx').assign_add(1)
    
        def _result(self):
            return self.metric.result()
    
        def _filtered_result(self):
            return (1 - self.alpha) * self.get_state_variable('accum') + self.alpha * self._result()
    
        def result(self):
            # idx == 0: no previous epoch has been folded in yet (accum still holds its
            # initial value), so return the raw metric instead of the filtered one
            return tf.cond(tf.equal(self.get_state_variable('idx'), 0), self._result, self._filtered_result)
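
    For checking the logged values by hand, the recurrence that SmoothMetricEMA is meant to compute can be written in plain Python. ema_sequence is a hypothetical helper for illustration; it assumes the filter state is folded in exactly once per epoch, so 'accum' always holds the previous epoch's smoothed value.

```python
def ema_sequence(values, alpha=0.1):
    """Per-epoch values SmoothMetricEMA is meant to report (plain-Python sketch)."""
    smoothed = []
    accum = None  # plays the role of the 'accum' state variable
    for value in values:
        if accum is None:
            # idx == 0: nothing folded in yet, report the raw metric.
            current = value
        else:
            current = (1 - alpha) * accum + alpha * value
        smoothed.append(current)
        accum = current  # update_filter_state_previous_epoch stores result()
    return smoothed


print(ema_sequence([1.0, 0.0, 0.0], alpha=0.5))  # [1.0, 0.5, 0.25]
```

    Seeding with the first raw value, rather than with the zero-initialised accumulator, keeps the first epochs from being dragged toward 0.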
    

    Here is the SMA filter alternative:

    class SmoothMetricSMA(SmoothMetric):
        def __init__(self, metric, n_samples=10, remove_outliers=None, name=None, **kwargs) -> None:
            super().__init__(metric=metric, name=name, **kwargs)
            self.queue_size = n_samples - 1
            self.remove_outliers = remove_outliers
            self.add_state_variable('accum', initial_value=tf.zeros(self.queue_size))
            self.add_state_variable('idx', initial_value=0)
    
        def update_filter_state_previous_epoch(self):
            self.get_state_variable('accum').scatter_update(
                tf.IndexedSlices(self.metric.result(), self.get_state_variable('idx') % self.queue_size))
            self.get_state_variable('idx').assign_add(1)
    
        def _result(self):
            return self.metric.result()
    
        def _filtered_result(self):
            accum = self.get_state_variable('accum')
            idx = self.get_state_variable('idx')
            # Stored values of previous epochs (at most queue_size) plus the current epoch.
            queue = tf.concat([accum[:idx], tf.reshape(self._result(), (1,))], axis=0)
            if self.remove_outliers is not None:
                # Drop a fraction remove_outliers of the values, half from each tail.
                queue = tf.sort(queue)
                eff_size = tf.cast(tf.shape(queue)[0], tf.float32)  # values actually in the window
                lo = tf.cast(tf.math.floor(eff_size * self.remove_outliers / 2), tf.int32)
                hi = tf.cast(tf.math.ceil(eff_size * (1 - self.remove_outliers / 2)), tf.int32)
                queue = queue[lo:hi]
            return tf.reduce_mean(queue)
    
        def result(self):
            # idx == 0: no previous epoch has been stored yet (accum is all zeros),
            # so return the raw metric instead of the filtered one
            return tf.cond(tf.equal(self.get_state_variable('idx'), 0), self._result, self._filtered_result)
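
    Likewise, here is a plain-Python sketch of the windowed mean with optional outlier trimming (sma_sequence is a hypothetical helper, not part of the classes above). The window holds the current epoch plus up to n_samples - 1 previous ones, matching queue_size = n_samples - 1; the trim counts are computed from the number of values actually in the window, and remove_outliers is assumed to be a fraction below 1.

```python
import math
from collections import deque


def sma_sequence(values, n_samples=10, remove_outliers=None):
    """Per-epoch simple moving average over the last n_samples values (sketch).

    If remove_outliers is a fraction r in [0, 1), the sorted window loses
    floor(n * r / 2) values at the low end and everything from
    ceil(n * (1 - r / 2)) on at the high end before averaging.
    """
    window = deque(maxlen=n_samples)  # current epoch + up to n_samples - 1 previous
    averages = []
    for value in values:
        window.append(value)
        queue = list(window)
        if remove_outliers is not None:
            queue.sort()
            n = len(queue)  # number of values actually in the window
            lo = math.floor(n * remove_outliers / 2)
            hi = math.ceil(n * (1 - remove_outliers / 2))
            queue = queue[lo:hi]
        averages.append(sum(queue) / len(queue))
    return averages


print(sma_sequence([1, 2, 3, 4], n_samples=2))  # [1.0, 1.5, 2.5, 3.5]
```

    Trimming makes the SMA robust to a single bad epoch: with remove_outliers=0.4 and a full window of five values, the largest and smallest values are discarded before averaging.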
    

    Finally, for the filter to work, the following callback must also be passed to fit. It tells each smooth metric whether the model is currently in training or test (validation) mode.

    class NotifySmoothMetricsCallback(tf.keras.callbacks.Callback):
        """
        Notifies the smooth metrics whether the model is in a train or test step.
        """
        def _set_test_mode(self, in_test_step):
            for metric in self.model.metrics:
                if isinstance(metric, SmoothMetric):
                    metric.in_test_step.assign(in_test_step)
    
        def on_test_begin(self, logs=None):
            self._set_test_mode(True)
    
        def on_test_end(self, logs=None):
            self._set_test_mode(False)
    
        def on_epoch_begin(self, epoch, logs=None):
            self._set_test_mode(False)
    
        def on_epoch_end(self, epoch, logs=None):
            self._set_test_mode(False)
    
        def on_train_begin(self, logs=None):
            self._set_test_mode(False)
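
    Why reset_state switches to prev_in_test_step before folding can be seen in a small trace without TensorFlow. ModeMock and trace are hypothetical helpers that mirror only the flag handling of SmoothMetric and the callback; the event order in trace (metrics reset right after on_test_begin, and again before the next training epoch) is an assumption about model.fit with validation_data, and for simplicity the trace starts directly with the first training batches.

```python
class ModeMock:
    """TensorFlow-free stand-in for SmoothMetric's train/test flag handling."""

    def __init__(self, use_prev=True):
        self.use_prev = use_prev
        self.in_test_step = False       # set by NotifySmoothMetricsCallback
        self.prev_in_test_step = False  # mode of the latest update_state call
        self.folded = []                # which filter state each reset updated

    def update_state(self):
        self.prev_in_test_step = self.in_test_step

    def reset_state(self):
        in_test = self.prev_in_test_step if self.use_prev else self.in_test_step
        self.folded.append('test' if in_test else 'train')


def trace(n_epochs, use_prev=True):
    m = ModeMock(use_prev)
    for epoch in range(n_epochs):
        if epoch > 0:
            m.reset_state()     # assumed: metrics reset before the next training epoch
        m.in_test_step = False  # on_epoch_begin
        m.update_state()        # training batches
        m.in_test_step = True   # on_test_begin
        m.reset_state()         # assumed: metrics reset when evaluation starts
        m.update_state()        # validation batches
        m.in_test_step = False  # on_test_end
    return m.folded


print(trace(2))                  # ['train', 'test', 'train']
print(trace(2, use_prev=False))  # ['test', 'train', 'test']
```

    With prev_in_test_step, every fold goes to the mode that produced the data; reading in_test_step directly would swap train and test, because the reset at the start of evaluation happens after the callback has already switched to test mode.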
    

    Here is an example of how to use these classes:

    metrics=[
        SmoothMetricEMA(tf.metrics.MeanSquaredError(), alpha=0.1, name='smooth'),
    ]
    callbacks = [
        tf.keras.callbacks.EarlyStopping(monitor='val_smooth', patience=10, verbose=1),
        NotifySmoothMetricsCallback(),
    ]
    model.compile(optimizer='adam', loss='mse', metrics=metrics)
    
    model.fit(x_train, y_train, batch_size=32, epochs=100,
        validation_data=(x_val, y_val),
        callbacks=callbacks,
    )