Using multiple validation sets with keras


I am training a model with Keras using the model.fit() method. I would like to use multiple validation sets, each evaluated separately after every training epoch, so that I get one loss value per validation set. If possible, the values should be both displayed during training and returned by the keras.callbacks.History() callback.

I am thinking of something like this:

history = model.fit(train_data, train_targets,
                    epochs=epochs,
                    batch_size=batch_size,
                    validation_data=[
                        (validation_data1, validation_targets1), 
                        (validation_data2, validation_targets2)],
                    shuffle=True)

I currently have no idea how to implement this. Is it possible to achieve this by writing my own Callback? Or how else would you approach this problem?


Solution

  • I ended up writing my own Callback, based on the History callback, to solve the problem. I'm not sure this is the best approach, but the following Callback records the losses and metrics for the training and validation sets, as the History callback does, and also records the losses and metrics for any additional validation sets passed to its constructor.

    from keras.callbacks import Callback

    class AdditionalValidationSets(Callback):
        def __init__(self, validation_sets, verbose=0, batch_size=None):
            """
            :param validation_sets:
            a list of 3-tuples (validation_data, validation_targets, validation_set_name)
            or 4-tuples (validation_data, validation_targets, sample_weights, validation_set_name)
            :param verbose:
            verbosity mode, 1 or 0
            :param batch_size:
            batch size to be used when evaluating on the additional datasets
            """
            super(AdditionalValidationSets, self).__init__()
            self.validation_sets = validation_sets
            for validation_set in self.validation_sets:
                if len(validation_set) not in [3, 4]:
                    raise ValueError('each validation set must be a 3-tuple or a 4-tuple')
            self.epoch = []
            self.history = {}
            self.verbose = verbose
            self.batch_size = batch_size
    
        def on_train_begin(self, logs=None):
            self.epoch = []
            self.history = {}
    
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            self.epoch.append(epoch)
    
            # record the same values as History() as well
            for k, v in logs.items():
                self.history.setdefault(k, []).append(v)
    
            # evaluate on the additional validation sets
            for validation_set in self.validation_sets:
                if len(validation_set) == 3:
                    validation_data, validation_targets, validation_set_name = validation_set
                    sample_weights = None
                elif len(validation_set) == 4:
                    validation_data, validation_targets, sample_weights, validation_set_name = validation_set
                else:
                    raise ValueError('each validation set must be a 3-tuple or a 4-tuple')
    
                results = self.model.evaluate(x=validation_data,
                                              y=validation_targets,
                                              verbose=self.verbose,
                                              sample_weight=sample_weights,
                                              batch_size=self.batch_size)
    
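                # evaluate() returns a bare scalar when the model has no metrics
                if not isinstance(results, (list, tuple)):
                    results = [results]
    
                for metric, result in zip(self.model.metrics_names, results):
                    valuename = validation_set_name + '_' + metric
                    self.history.setdefault(valuename, []).append(result)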
    

    which I then use like this:

    history = AdditionalValidationSets([(validation_data2, validation_targets2, 'val2')])
    model.fit(train_data, train_targets,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(validation_data1, validation_targets1),
              callbacks=[history],
              shuffle=True)
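
After training, the callback's history attribute should hold one list per recorded value, keyed by the set name plus the metric name (for example 'val2_loss'), alongside the usual 'loss' and 'val_loss' entries. The following is only a sketch of how those lists might be inspected. The helper best_epochs is a hypothetical name, not part of Keras, and the numbers are placeholders standing in for history.history and history.epoch, not real training results. The exact metric key (such as 'val2_accuracy') depends on how the model was compiled.

```python
def best_epochs(history, epochs):
    """Return {name: (epoch, value)} for the lowest value of every '*loss' series."""
    best = {}
    for name, values in history.items():
        # Only loss series are minimised; other metrics are skipped.
        if not name.endswith("loss") or not values:
            continue
        idx = min(range(len(values)), key=values.__getitem__)
        best[name] = (epochs[idx], values[idx])
    return best


# Placeholder numbers for illustration only, shaped like the callback's
# `history.epoch` and `history.history` attributes.
example_epochs = [0, 1, 2]
example_history = {
    "loss": [0.9, 0.6, 0.4],
    "val_loss": [1.0, 0.7, 0.8],
    "val2_loss": [1.2, 1.1, 0.9],
    "val2_accuracy": [0.5, 0.6, 0.7],
}

summary = best_epochs(example_history, example_epochs)
for name, (epoch, value) in sorted(summary.items()):
    print(f"{name}: best at epoch {epoch} ({value})")
```

With a real run, the same helper could be called as best_epochs(history.history, history.epoch), where history is the AdditionalValidationSets instance passed to model.fit().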