python tensorflow keras loss-function

Is there a differentiable alternative to K.cast?


For a custom Keras loss function, I need to create a float tensor from a bool tensor. Unfortunately, K.cast() is not differentiable and therefore can't be used. Is there an alternative way to do this that is differentiable?

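from tensorflow.keras import backend as K
less_than_tau = y_pred < tau                      # elementwise comparison -> bool tensor
less_than_tau = K.cast(less_than_tau, 'float32')  # 0.0 / 1.0 mask, but no gradient flows back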
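A minimal sketch, assuming TensorFlow 2.x with eager execution, that shows where the gradient is lost. The values of `tau` and `y_pred` are hypothetical. `tf.GradientTape` returns `None` here because the comparison produces a boolean tensor, which carries no gradient, so nothing flows back through the comparison or the cast that follows it.

```python
import tensorflow as tf

# Hypothetical example values; tau plays the role of the threshold in the question.
tau = 0.5
y_pred = tf.Variable([0.1, 0.4, 0.6, 0.9])

with tf.GradientTape() as tape:
    # Hard mask: the comparison yields a bool tensor, the cast turns it into 0.0 / 1.0.
    less_than_tau = tf.cast(y_pred < tau, tf.float32)
    loss = tf.reduce_sum(less_than_tau)

# No differentiable path connects loss to y_pred, so the tape returns None.
hard_grad = tape.gradient(loss, y_pred)
print(hard_grad)
```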

Solution

  • Dr. Snoopy is right.

    The usual way to handle this in deep learning is to replace hard operations with "soft" (smooth) approximations, such as softmax in place of a hard argmax.

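    In your case, to get a soft version of y_pred < tau (close to 1 where y_pred is well below tau, close to 0 where it is well above), you'd do something like

    from tensorflow.keras import backend as K
    switch = K.sigmoid(tau - y_pred)  # soft (y_pred < tau), values in (0, 1)
    loss = switch * true_case + (1. - switch) * false_case  # true_case / false_case: your two loss terms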
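The sketch below puts the sigmoid switch into a complete custom loss, written with plain `tf` ops. It is illustrative only: `TAU`, `TEMPERATURE`, `soft_less_than`, `thresholded_loss`, and the two branch losses (squared error and absolute error) are hypothetical stand-ins for your own threshold and `true_case` / `false_case`. The temperature is an addition not in the answer above; it controls how closely the sigmoid approximates the hard step.

```python
import tensorflow as tf

TAU = 0.5           # hypothetical threshold (the question's tau)
TEMPERATURE = 0.05  # hypothetical sharpness; smaller -> closer to a hard step


def soft_less_than(x, tau, temperature):
    """Differentiable stand-in for tf.cast(x < tau, tf.float32).

    Returns values in (0, 1): near 1 where x is well below tau,
    near 0 where x is well above it, and 0.5 at x == tau.
    """
    return tf.sigmoid((tau - x) / temperature)


def thresholded_loss(y_true, y_pred):
    # Hypothetical per-branch losses; replace with your own true_case / false_case.
    true_case = tf.square(y_true - y_pred)  # weighted up where y_pred < TAU
    false_case = tf.abs(y_true - y_pred)    # weighted up where y_pred >= TAU
    switch = soft_less_than(y_pred, TAU, TEMPERATURE)
    # Blend the two branches instead of selecting one with a hard 0/1 mask.
    return tf.reduce_mean(switch * true_case + (1.0 - switch) * false_case)


# Check that gradients now reach y_pred.
y_true = tf.constant([0.0, 1.0, 0.0, 1.0])
y_pred = tf.Variable([0.1, 0.4, 0.6, 0.9])
with tf.GradientTape() as tape:
    loss = thresholded_loss(y_true, y_pred)
grad = tape.gradient(loss, y_pred)
print(loss.numpy(), grad.numpy())
```

To train with it, pass the function to `model.compile(optimizer="adam", loss=thresholded_loss)`. The temperature is a trade-off: if it is too small, the sigmoid saturates and gradients vanish away from `TAU`, much as with the hard mask; if it is too large, the blend no longer resembles the original threshold.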