Tags: python, tensorflow, keras, generative-adversarial-network

Change y_true in a custom metric


I'm trying to implement a GAN in Keras, and I want to use the one-sided label smoothing trick, i.e. label real images 0.9 instead of 1. However, the built-in binary_accuracy metric now does the wrong thing: it is always 0 for real images.
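To see why the metric collapses, here is a NumPy re-implementation of the comparison binary accuracy performs: threshold the prediction at 0.5 to get a hard 0/1 decision, then test it for equality with the label. This is a sketch of the logic, not Keras's own code; the function name `binary_accuracy_np` and the prediction values are hypothetical.

```python
import numpy as np

def binary_accuracy_np(y_true, y_pred, threshold=0.5):
    """NumPy sketch: fraction of labels exactly equal to the thresholded prediction."""
    # Turn probabilities into hard 0/1 decisions
    y_pred_hard = (y_pred > threshold).astype(y_true.dtype)
    # Exact equality against the label
    return np.mean(y_true == y_pred_hard)

# Real images with one-sided smoothed labels (0.9 instead of 1)
y_true = np.array([0.9, 0.9, 0.9, 0.9])
y_pred = np.array([0.95, 0.8, 0.6, 0.3])  # hypothetical discriminator outputs

print(binary_accuracy_np(y_true, y_pred))      # 0.0: a hard 0 or 1 never equals 0.9
print(binary_accuracy_np(np.ones(4), y_pred))  # 0.75 with unsmoothed labels
```

Because the hard prediction is always exactly 0 or 1, any label of 0.9 can never match, regardless of how good the discriminator is.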

So I tried to implement my own metric in Keras. I want to convert every 0.9 label to 1, but I'm new to Keras and don't know how to do that. Here's what I intend:

# Just pseudocode
from keras import backend as K
from keras import metrics

def custom_metrics(y_true, y_pred):
    if K.equal(y_true, [[0.9]]):
        y_true = y_true + 0.1
    return metrics.binary_accuracy(y_true, y_pred)
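The pseudocode needs a decision per element, which a single Python `if` cannot make on a tensor; the usual tool is an element-wise select. As a sketch of the intended mapping (on a NumPy array, not a Keras tensor), `np.where` replaces only the 0.9 labels and leaves every other label untouched. Using `np.isclose` rather than `==` is an assumption to tolerate float rounding, and the helper name `unsmooth_labels` is hypothetical.

```python
import numpy as np

def unsmooth_labels(y_true, smoothed=0.9, target=1.0):
    """Replace every smoothed label with `target`; keep all other labels as they are."""
    # Element-wise mask: True where the label is (approximately) the smoothed value
    mask = np.isclose(y_true, smoothed)
    # Element-wise select instead of one Python `if` for the whole array
    return np.where(mask, target, y_true)

y_true = np.array([[0.0], [0.9], [0.9], [0.0]])
print(unsmooth_labels(y_true).ravel())  # [0. 1. 1. 0.]
```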

How should I compare and change the y_true label? Thanks in advance!


EDIT: The output of the following code is:

from keras import backend as K
from keras import metrics

def custom_metrics(y_true, y_pred):
    print(K.shape(y_true))
    print(K.shape(y_pred))
    y_true = K.switch(K.equal(y_true, 0.9), K.ones_like(y_true), K.zeros_like(y_true))
    return metrics.binary_accuracy(y_true, y_pred)

Tensor("Shape:0", shape=(2,), dtype=int32) Tensor("Shape_1:0", shape=(2,), dtype=int32)

ValueError: Shape must be rank 0 but is rank 2 for 'cond/Switch' (op: 'Switch') with input shapes: [?,?], [?,?].


Solution

  • You can use tf.where:

    y_true = tf.where(K.equal(y_true, 0.9), tf.ones_like(y_true), tf.zeros_like(y_true))
    

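    Alternatively, you can use the keras.backend.switch function. If you get the error shown in the question, your Keras version most likely implements switch only for a scalar condition (via tf.cond, hence the 'cond/Switch' op in the message); newer versions handle an element-wise condition with tf.where.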

    keras.backend.switch(condition, then_expression, else_expression)
    

    Your custom metric function would then look something like this:

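    from keras import backend as K
    from keras import metrics

    def custom_metrics(y_true, y_pred):
        # Map the smoothed label 0.9 to 1 and every other label to 0
        y_true = K.switch(K.equal(y_true, 0.9), K.ones_like(y_true), K.zeros_like(y_true))
        return metrics.binary_accuracy(y_true, y_pred)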
    

    Test code:

    import numpy as np
    from keras import backend as K

    def test_function(y_true):
        print(K.eval(y_true))
        y_true = K.switch(K.equal(y_true, 0.9), K.ones_like(y_true), K.zeros_like(y_true))
        print(K.eval(y_true))

    y_true = K.variable(np.array([0, 0, 0, 0, 0, 0.9, 0.9, 0.9, 0.9, 0.9]))
    test_function(y_true)
    

    output:

    [0.  0.  0.  0.  0.  0.9 0.9 0.9 0.9 0.9]
    [0. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
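One caveat with the switch/where mapping: it sends 0.9 to 1 and every other value to 0, so an unsmoothed real label of 1.0 would be scored as 0. The NumPy sketch below (not Keras code; the label values are made up for illustration) compares that exact-equality rule with a hypothetical alternative that thresholds the label at 0.5 instead.

```python
import numpy as np

# Labels mixing smoothed real (0.9), unsmoothed real (1.0) and fake (0.0) examples
y_true = np.array([0.0, 0.9, 1.0, 0.0, 0.9], dtype=np.float32)

# Same rule as the switch/where above: 0.9 -> 1, everything else -> 0
equality_mapped = np.where(y_true == np.float32(0.9), 1.0, 0.0)

# Threshold rule: any label above 0.5 counts as "real"
threshold_mapped = (y_true > 0.5).astype(np.float32)

print(equality_mapped)   # [0. 1. 0. 0. 1.]  (the 1.0 label became 0)
print(threshold_mapped)  # [0. 1. 1. 0. 1.]
```

Inside the Keras metric, the threshold variant would be something like `y_true = K.cast(K.greater(y_true, 0.5), K.floatx())`, which also avoids relying on exact float equality with 0.9. This is one possible design, assuming real labels are always above 0.5 and fake labels below it.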