
Custom metric access X input data

I'd like to write a custom metric for a spelling-correction model that counts the letters that were previously incorrect and have now been substituted correctly. It should also count the letters that were previously correct but have now been substituted incorrectly.

That's why I need access to the input data (x_input). Unfortunately, only y_true and y_pred are passed to the metric by default. Is there a workaround to get at the x_input that matches each y_true/y_pred pair?


Keras calls a custom metric with this signature:

def custom_metric(y_true, y_pred):
    ...

but I would need something like this:

def custom_metric(x_input, y_true, y_pred):
    ...
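Once x_input is available, the counting logic itself could look roughly like this. It is a framework-neutral NumPy sketch that assumes input, target and prediction are aligned sequences of integer character ids with equal length (substitutions only, no insertions or deletions). `count_substitutions` is a hypothetical helper name, and inside a Keras metric the same expressions would have to be written with TensorFlow ops instead of NumPy.

```python
import numpy as np


def count_substitutions(x_input, y_true, y_pred):
    """Return (fixed, broken) letter counts for aligned character-id arrays.

    Assumption: all three arrays have the same shape and position i of each
    refers to the same letter (substitution-only spelling errors).
    """
    x_input = np.asarray(x_input)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    was_wrong = x_input != y_true  # letter was incorrect in the input
    now_right = y_pred == y_true   # model output matches the target

    # Previously incorrect letter that the model replaced with the right one
    fixed = int(np.sum(was_wrong & now_right))
    # Previously correct letter that the model changed into a wrong one
    broken = int(np.sum(~was_wrong & ~now_right))
    return fixed, broken


fixed, broken = count_substitutions([1, 2, 3, 4], [1, 9, 3, 4], [1, 9, 5, 4])
print(fixed, broken)
```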


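  • import tensorflow as tf

    def custom_loss(x_input):
        def loss_fn(y_true, y_pred):
            # Use your x_input here directly; it is captured by the closure
            return tf.reduce_mean(tf.square(y_true - y_pred))  # placeholder: your loss value
        return loss_fn

    model = ...  # define your model here
    # Keras passes y_true and y_pred implicitly; x_input comes from the closure
    model.compile(optimizer="adam", loss=custom_loss(x_input))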

    Keep in mind that x_input is captured once, when custom_loss(x_input) is called, so the loss sees the same x_input for every batch during training; it is not the input of the current batch.
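To make that limitation concrete, here is a minimal NumPy sketch of the closure pattern (the factory name `make_fixed_count_metric` is hypothetical). The returned function keeps the `(y_true, y_pred)` signature Keras expects, but its `x_input` is frozen when the factory is called, so its counts only make sense when `y_true` and `y_pred` line up with that exact array, for example when the whole dataset is evaluated as one batch.

```python
import numpy as np


def make_fixed_count_metric(x_input):
    # x_input is captured once, at the moment the factory is called
    x_input = np.asarray(x_input)

    def fixed_count(y_true, y_pred):
        # y_true / y_pred are supplied by the caller (Keras, in a real setup)
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        # Count letters that were wrong in x_input and are right in y_pred
        return int(np.sum((x_input != y_true) & (y_pred == y_true)))

    return fixed_count


metric = make_fixed_count_metric([[1, 2, 3, 4]])
fixed = metric([[1, 9, 3, 4]], [[1, 9, 5, 4]])
print(fixed)
```

Every later call compares against the same captured `[[1, 2, 3, 4]]`, whatever batch the labels actually came from.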


    Since you only need the x_input of each batch inside the loss function, and you are writing your own custom loss anyway, why not pass x_input as the labels? Something like this:

    model.fit(x=x_input, y=x_input)
    def custom_loss(y_true, y_pred):
        # y_true corresponds to the x_input data of the current batch
        x_input = y_true
        return ...  # compute your loss from x_input and y_pred

    If you need x_input and some other data as well (for example the true targets), you can pass both:

    model.fit(x=x_input, y=[x_input, other_data])

    Inside the loss you then just need to split y_true back into x_input and other_data.
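A sketch of the pack-and-decouple idea, written with NumPy so it runs on its own; `pack_targets` and `substitution_counts` are hypothetical names. It assumes a single-output model that predicts per-position probabilities over a character vocabulary. Instead of a Python list, the two label arrays are stacked along a new last axis, because Keras may interpret a list passed as `y` as targets for several outputs; the metric then slices `y_true` back apart. In an actual Keras metric you would replace the NumPy calls with TensorFlow ones (`tf.argmax`, `tf.cast`, `tf.reduce_sum`) and register the function via `model.compile(metrics=[...])`.

```python
import numpy as np


def pack_targets(x_input, y_target):
    # Stack input ids and target ids on a new last axis -> (batch, seq_len, 2)
    return np.stack([x_input, y_target], axis=-1)


def substitution_counts(y_true, y_pred):
    # Decouple the packed labels again
    x_input = y_true[..., 0]
    y_target = y_true[..., 1]
    # Assumed y_pred layout: per-position probabilities (batch, seq_len, vocab)
    pred_ids = np.argmax(y_pred, axis=-1)

    was_wrong = x_input != y_target
    now_right = pred_ids == y_target
    fixed = int(np.sum(was_wrong & now_right))     # wrong -> right
    broken = int(np.sum(~was_wrong & ~now_right))  # right -> wrong
    return fixed, broken


x = np.array([[1, 2, 3, 4]])
y = np.array([[1, 0, 3, 4]])
y_packed = pack_targets(x, y)                    # would be passed as fit(..., y=y_packed)
y_pred = np.eye(5)[np.array([[1, 0, 2, 4]])]     # one-hot "probabilities", vocab size 5
result = substitution_counts(y_packed, y_pred)
print(result)
```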