
Weighting Training Data in Keras


Problem

I want to train a Keras 2 neural network (Theano backend) on data of varying relevance: some samples are less important than others and should affect training less. However, I can't simply omit them, because the data is a time series that feeds into Conv1D layers.

Question

How can I tell Keras to weight certain training samples less than others during training?

Idea

I'm thinking about defining my own loss function that takes y_true, y_pred, and a third argument, y_weight. Something like:

from keras import backend as K

def mean_squared_error_weighted(y_true, y_pred, y_weight):
    return y_weight * K.mean(K.square(y_pred - y_true), axis=-1)  # per-sample MSE scaled by that sample's weight

But how would I get Keras to pass that third argument to the loss function?
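A Keras loss function only receives (y_true, y_pred), so the three-argument version cannot be passed to compile() as-is. The arithmetic the idea is after can still be checked outside Keras. Below is a minimal NumPy sketch (not Keras code): each sample's MSE is multiplied by its weight before averaging over the batch. Dividing by the number of samples is an assumption here; different Keras versions normalize the weighted loss slightly differently.

```python
import numpy as np

def mean_squared_error_weighted(y_true, y_pred, y_weight):
    """Weighted MSE over a batch (NumPy sketch, not the Keras API).

    y_true, y_pred: arrays of shape (n_samples, n_outputs)
    y_weight: array of shape (n_samples,), one weight per sample
    """
    # Per-sample MSE: average the squared error over the output axis
    per_sample = np.mean(np.square(y_pred - y_true), axis=-1)
    # Scale each sample's loss by its weight, then average over the batch
    return np.mean(y_weight * per_sample)

y_true = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
y_pred = np.array([[1.0, 1.0], [1.0, 1.0], [0.0, 0.0]])
weights = np.array([1.0, 1.0, 0.25])  # the third sample counts a quarter as much

print(mean_squared_error_weighted(y_true, y_pred, weights))  # ≈ 0.667 (≈ 1.667 with equal weights)
```

The third sample has the largest error (4.0), so down-weighting it pulls the batch loss from 5/3 down to 2/3 without removing the sample.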


Solution

  • The fit method of a Keras model accepts an optional sample_weight argument that does exactly what you're looking for. From the Keras documentation:

    sample_weight: Optional Numpy array of weights for the training samples, used for weighting the loss function (during training only).
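A minimal sketch of preparing that array for a Conv1D time-series setup. Everything here is hypothetical: the random data, the window shape (n_samples, timesteps, channels), the "every fourth window is less relevant" rule, and the 0.1 down-weight. The point is that the weights form a flat array with one entry per training sample, so the less relevant windows stay in the data but contribute less to the loss.

```python
import numpy as np

# Hypothetical data: 100 windows of 32 timesteps with 1 channel each,
# shaped as Conv1D expects: (n_samples, timesteps, channels)
rng = np.random.default_rng(0)
x_train = rng.normal(size=(100, 32, 1))
y_train = rng.normal(size=(100, 1))

# Hypothetical relevance flag: True marks samples that should matter less
less_relevant = np.zeros(len(x_train), dtype=bool)
less_relevant[::4] = True  # e.g. every fourth window

# One weight per sample: 1.0 for normal samples, 0.1 for less relevant ones
weights = np.where(less_relevant, 0.1, 1.0)

# With a compiled Keras model (not built in this sketch), pass the weights to fit():
# model.fit(x_train, y_train, sample_weight=weights)
```

Because the weights are given per sample, no custom loss function is needed: the built-in loss (for example "mean_squared_error") is weighted for you during training.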