Tags: keras, deep-learning, loss-function

Getting the training data in every batch in Keras


I would like to know whether it is possible to get the training data used in each batch in Keras.

Getting y_true and y_pred is easy, but I also want the training inputs that were used to make the predictions in that batch.

from keras import backend as K

def my_loss(y_true, y_pred):
    loss = K.mean(K.abs(y_true - y_pred))  # mean absolute error
    return loss

model.compile(loss=my_loss, optimizer='rmsprop', metrics=['mae'])

This works, but I want something like this:

def my_loss(y_true, y_pred, x_train):
    ...

but using it fails with:

TypeError: my_loss() missing 1 required positional argument: 'x_train'
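The error comes from the calling convention: Keras calls a loss function with only y_true and y_pred, so a third required parameter is never filled in. A minimal plain-Python sketch of that situation, assuming a hypothetical `keras_style_call` helper that mimics the two-argument call (it is not a Keras function):

```python
def keras_style_call(loss_fn, y_true, y_pred):
    # Hypothetical stand-in for how Keras invokes a loss:
    # exactly two positional arguments, y_true and y_pred.
    return loss_fn(y_true, y_pred)


def my_loss(y_true, y_pred, x_train):
    # Three-argument loss, as in the question (the body does not matter here).
    return 0.0


message = ""
try:
    keras_style_call(my_loss, [1.0], [0.5])
except TypeError as err:
    message = str(err)

print(message)  # my_loss() missing 1 required positional argument: 'x_train'
```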

Thanks for any help.


Solution

  • If you want to pass parameters other than y_true and y_pred, you can wrap the loss in an outer function that takes them and returns the two-argument loss Keras expects:

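    from keras import backend as K

    def custom_loss(x_train):
        # The outer function captures x_train; Keras only ever calls the inner my_loss.
        def my_loss(y_true, y_pred):
            loss = K.mean(K.abs(y_true - y_pred))
            # do something with x_train
            return loss

        return my_loss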
    

    When compiling, pass the model's own input tensor (the same Input the model is built from); its shape matches a single sample of x_train.

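    from keras.layers import Input
    input_tensor = Input(shape=input_shape)  # per-sample shape of x_train (batch dimension excluded)
    # ... build `model` from input_tensor, e.g. model = Model(inputs=input_tensor, outputs=...)
    model.compile(loss=custom_loss(input_tensor), optimizer='rmsprop', metrics=['mae'])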
    

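    This is how you can define your custom loss. Because input_tensor is the model's input, Keras feeds it the same batch it feeds the model, so inside my_loss it refers to the current batch of x_train.
    If you want full control over which samples go into each batch, handle the batching yourself and train with model.train_on_batch(x_batch, y_batch) for each batch.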
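The same closure-plus-manual-batching idea can be sketched in plain Python, with no Keras involved, to show that the loss sees exactly the inputs of the batch being trained on. This is an illustrative analogue, not Keras code: `batch_state`, `iterate_batches` and `predict` are hypothetical names, the mutable dict stands in for the model's input tensor, and the loop body stands in for a `model.train_on_batch` call. Weighting each error by its input is just an arbitrary way to "do something with x_train".

```python
def custom_loss(batch_state):
    """Return a two-argument loss that can also read the current input batch.

    batch_state is a hypothetical mutable dict that the training loop below
    updates before each step; it stands in for the model's input tensor.
    """
    def my_loss(y_true, y_pred):
        x_batch = batch_state["x"]  # inputs used to produce y_pred
        # Illustrative use of the inputs: weight each absolute error by its input value.
        weighted = [abs(t - p) * x for t, p, x in zip(y_true, y_pred, x_batch)]
        return sum(weighted) / len(weighted)

    return my_loss


def iterate_batches(x, y, batch_size):
    """Yield consecutive (x_batch, y_batch) slices; the last batch may be smaller."""
    for start in range(0, len(x), batch_size):
        yield x[start:start + batch_size], y[start:start + batch_size]


def predict(x_batch):
    """Hypothetical stand-in for a model: always off by exactly 1 from y = 2x."""
    return [2.0 * x + 1.0 for x in x_batch]


x_train = [1.0, 2.0, 3.0, 4.0, 5.0]
y_train = [2.0, 4.0, 6.0, 8.0, 10.0]

state = {"x": None}
loss_fn = custom_loss(state)

batch_losses = []
for x_batch, y_batch in iterate_batches(x_train, y_train, batch_size=2):
    state["x"] = x_batch  # make the current inputs visible to the loss
    batch_losses.append(loss_fn(y_batch, predict(x_batch)))

print(batch_losses)  # [1.5, 3.5, 5.0]
```

Keras still only ever calls a two-argument function; the extra data reaches the loss through the enclosing scope rather than through the call signature.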