keras neural-network deep-learning loss-function imbalanced-data

Custom loss function (focal loss) input size error in Keras


I am using a neural network for multi-class classification. There are 3 imbalanced classes, so I'd like to use focal loss to handle the imbalance. I pass a custom loss function to a Keras Sequential model. I tried several versions of focal loss code that I found online, but they all return the same error message, which basically says the input size is the batch size while 1 was expected. Could anyone have a look at the issue and let me know if you can fix it? I really appreciate it!

model = build_keras_model(x_train, name='training1')

class FocalLoss(keras.losses.Loss):
    def __init__(self, gamma=2., alpha=4.,
                 reduction=tf.keras.losses.Reduction.AUTO, name='focal_loss'):
        super(FocalLoss, self).__init__(reduction=reduction,
                                        name=name)
        self.gamma = float(gamma)
        self.alpha = float(alpha)

    def call(self, y_true, y_pred):
        epsilon = 1.e-9
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)
        model_out = tf.add(y_pred, epsilon)
        ce = tf.multiply(y_true, -tf.math.log(model_out))
        weight = tf.multiply(y_true, tf.pow(
            tf.subtract(1., model_out), self.gamma))
        fl = tf.multiply(self.alpha, tf.multiply(weight, ce))
        reduced_fl = tf.reduce_max(fl, axis=1)
        return tf.reduce_mean(reduced_fl)

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss=FocalLoss(alpha=1),
              metrics=['accuracy'])

class_weight = {0: 1.,
                1: 6.,
                2: 6.}

# fit the model (train for 5 epochs)
history = model.fit(x=x_train, y=y_train, batch_size=64, epochs=5, class_weight=class_weight)

ValueError: Can not squeeze dim[0], expected a dimension of 1, got 64 for 'loss/output_1_loss/weighted_loss/Squeeze' (op: 'Squeeze') with input shapes: [64].

Solution

  • You are relying on a helper class that handles some of the logic for you, but its documentation is not clear about what exactly it does and, hence, what you still need to do yourself.

    In this case you use tf.keras.losses.Loss. All you need to do is implement call() (and optionally __init__()). Unfortunately, the documentation doesn't state what it expects call() to return. But since you have to specify a reduction in __init__(), we can assume that call() is expected to return more than a single number; otherwise the reduction would be useless. In other words: the error is telling you that call() returns a single number while it is expected to return 64 numbers (one per sample in your batch).

    So, instead of reducing the batch to a single number yourself (by calling tf.reduce_mean(reduced_fl)), let the helper class do this for you and just return reduced_fl directly. You currently use reduction=tf.keras.losses.Reduction.AUTO, which is likely what you want.
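To make the shape argument concrete, here is a minimal NumPy sketch of the same arithmetic as call(), stopping before the batch mean. It is an illustration only, not Keras code: the function name focal_loss_per_sample, the toy batch, and the final weighted mean are assumptions chosen to roughly mimic what a batch-size-averaging reduction with class weights does.

```python
import numpy as np

def focal_loss_per_sample(y_true, y_pred, gamma=2.0, alpha=1.0, epsilon=1e-9):
    """Same arithmetic as call() in the question, but without the batch mean.

    Returns one loss value per sample, i.e. an array of shape (batch,).
    """
    y_true = np.asarray(y_true, dtype=np.float32)
    model_out = np.asarray(y_pred, dtype=np.float32) + epsilon
    ce = y_true * -np.log(model_out)                    # cross-entropy term
    weight = y_true * np.power(1.0 - model_out, gamma)  # focal modulating factor
    fl = alpha * weight * ce
    return fl.max(axis=1)                               # keep the batch axis

# Hypothetical toy batch: 4 samples, 3 classes, one-hot labels.
y_true = np.eye(3, dtype=np.float32)[[0, 1, 2, 0]]
y_pred = np.array([[0.70, 0.20, 0.10],
                   [0.30, 0.50, 0.20],
                   [0.25, 0.25, 0.50],
                   [0.90, 0.05, 0.05]], dtype=np.float32)

per_sample = focal_loss_per_sample(y_true, y_pred)

# The class weights from the question turn into one weight per sample.
class_weight = {0: 1.0, 1: 6.0, 2: 6.0}
sample_weight = np.array([class_weight[c] for c in y_true.argmax(axis=1)])

# The per-sample weights can only be applied while the loss still has
# shape (batch,); the reduction to a single number comes afterwards.
batch_loss = (per_sample * sample_weight).mean()

print(per_sample.shape, sample_weight.shape, np.ndim(batch_loss))
```

If the loss were already a single number at this point, the per-sample weights would have nothing to line up with, which matches the Squeeze error complaining about a shape of [64]. In the Keras class, the corresponding change is to end call() with return reduced_fl.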