Tags: tensorflow, keras, metrics

Custom metric: access the X input data


I'd like to write a custom metric for a spelling-correction model that counts the letters that were previously incorrect and are now correctly substituted. It should also count the letters that were previously correct and are now incorrectly substituted.

That's why I need access to the x_input data. Unfortunately, only y_true and y_pred are accessible by default. Is there a workaround to get to the matching x_input?

What I have:

def custom_metric(y_true, y_pred):

What I want:

def custom_metric(x_input, y_true, y_pred):
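The two counts described in the question can be sketched framework-free in NumPy, assuming the input, target and prediction are already aligned sequences of characters (or character ids) of equal length. `substitution_counts` is a hypothetical helper for illustration, not a Keras API:

```python
import numpy as np


def substitution_counts(x_ids, target_ids, pred_ids):
    """Return (fixed, broken) letter counts for aligned sequences.

    fixed:  positions that were wrong in the input and are right in the prediction
    broken: positions that were right in the input and are wrong in the prediction
    """
    x_ids, target_ids, pred_ids = map(np.asarray, (x_ids, target_ids, pred_ids))
    was_wrong = x_ids != target_ids   # input letter differed from the target
    now_right = pred_ids == target_ids  # prediction matches the target
    fixed = int(np.sum(was_wrong & now_right))
    broken = int(np.sum(~was_wrong & ~now_right))
    return fixed, broken


# "w" -> "e" is fixed, while the correct "l" is broken into "k"
print(substitution_counts(list("spwlling"), list("spelling"), list("spelking")))  # (1, 1)
```

Both counts need the input alongside the target and the prediction, which is exactly why the standard `(y_true, y_pred)` signature is not enough.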

Solution

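  • def custom_loss(x_input):
        def loss_fn(y_true, y_pred):
            # x_input is captured by the closure and can be used here directly
            loss = ...  # compute your loss value from x_input, y_true and y_pred
            return loss
        return loss_fn

    model = ...  # define your model here
    model.compile(optimizer="adam", loss=custom_loss(x_input))
    # Keras passes y_true and y_pred to loss_fn implicitly.
    # The same pattern works for a metric: metrics=[custom_loss(x_input)]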
    

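    Keep in mind that x_input is captured once, when custom_loss(x_input) is called. It is the whole array, with the same values for every batch during training, not the slice of inputs that belongs to the current batch.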
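To make this caveat concrete, here is a framework-free simulation in plain NumPy (not Keras; the batch loop only imitates what fit() does) of what the closure sees. `inspect_fn` is a hypothetical stand-in for a loss that just reports how many rows it receives:

```python
import numpy as np


def custom_loss(x_input):
    def inspect_fn(y_true, y_pred):
        # Hypothetical stand-in for a loss: report how many rows of x_input
        # and of y_true it can see. x_input is the array captured when
        # custom_loss() was called, so it never follows the current batch.
        return x_input.shape[0], y_true.shape[0]
    return inspect_fn


x_input = np.arange(6).reshape(6, 1)  # 6 training samples
y = x_input * 10                      # matching labels
loss_fn = custom_loss(x_input)

# Imitate fit() with batch_size=2: only y_true/y_pred are sliced per batch.
seen = [loss_fn(y[i:i + 2], y[i:i + 2]) for i in range(0, 6, 2)]
print(seen)  # [(6, 2), (6, 2), (6, 2)]
```

Every call sees all 6 input rows but only the 2 label rows of its batch, so inside the loss there is no way to tell which rows of x_input belong to the current y_true.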

    EDIT:

    Since you only need the x_input data of the current batch inside the loss function, and you are writing your own custom loss anyway, you can pass x_input as the labels. Something like this:

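    def custom_loss(y_true, y_pred):
        # y_true now holds the x_input values of the current batch
        loss = ...  # compute your loss value
        return loss

    model.compile(optimizer="adam", loss=custom_loss)  # pass the function itself, don't call it
    model.fit(x=x_input, y=x_input)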
    

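    If you need x_input together with other data (for example the real targets), you can pack both into the labels. With a single-output model, concatenate them into one label array rather than passing a list, because Keras matches a list of targets to a list of model outputs:

    import numpy as np
    y_packed = np.concatenate([x_input, other_data], axis=-1)  # shapes must match except on the last axis
    model.compile(optimizer="adam", loss=custom_loss)
    model.fit(x=x_input, y=y_packed)

    Inside custom_loss you then split y_true back into its x_input part and its other_data part, for example by slicing along the last axis.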
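Combining the label-packing idea with the metric from the question, here is a NumPy sketch (not Keras code) of splitting packed labels inside the metric. It assumes x_input and the targets are integer character ids of shape (batch, seq_len) and that the model outputs per-character probabilities of shape (batch, seq_len, vocab); `pack_labels`, `unpack_labels`, `correction_counts` and `seq_len` are hypothetical names. In a real Keras metric you would do the same slicing on tensors, use `tf.argmax` and `tf.reduce_sum`, and most likely write one metric per count, since a metric function returns a single value:

```python
import numpy as np


def pack_labels(x_input, targets):
    """Concatenate input ids and target ids on the last axis so both travel as y."""
    return np.concatenate([x_input, targets], axis=-1)


def unpack_labels(y_true, seq_len):
    """Split packed labels of shape (batch, 2 * seq_len) back into (x_ids, target_ids)."""
    return y_true[:, :seq_len], y_true[:, seq_len:]


def correction_counts(y_true, y_pred, seq_len):
    """Count letters fixed and letters broken by the model for one batch."""
    x_ids, target_ids = unpack_labels(y_true, seq_len)
    pred_ids = np.argmax(y_pred, axis=-1)  # probabilities -> predicted character ids
    fixed = np.sum((x_ids != target_ids) & (pred_ids == target_ids))
    broken = np.sum((x_ids == target_ids) & (pred_ids != target_ids))
    return int(fixed), int(broken)


x = np.array([[1, 2, 3, 4]])        # input with a wrong id at position 1
targets = np.array([[1, 5, 3, 4]])  # corrected sequence
y_pred = np.eye(10)[np.array([[1, 5, 9, 4]])]  # one-hot "probabilities", vocab of 10

print(correction_counts(pack_labels(x, targets), y_pred, seq_len=4))  # (1, 1)
```

Slicing on axis 1 explicitly, rather than with `...`, keeps the split correct even if the labels arrive with an extra trailing dimension.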