I'm building a model for time series classification. The data is highly imbalanced, so I've decided to use weighted cross entropy as my loss function.
TensorFlow provides tf.nn.weighted_cross_entropy_with_logits, but I'm not sure how to use it in TF 2.0. Because my model is built with the tf.keras API, I was thinking about writing a custom loss function like this:
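```python
import tensorflow as tf
pos_weight = 10.0  # extra weight on the positive (minority) class

def weighted_cross_entropy_with_logits(y_true, y_pred):
    # y_pred must be raw logits, so the last layer should have no sigmoid
    y_true = tf.cast(y_true, y_pred.dtype)  # labels often arrive as ints
    return tf.nn.weighted_cross_entropy_with_logits(labels=y_true, logits=y_pred, pos_weight=pos_weight)
# .....
model.compile(loss=weighted_cross_entropy_with_logits, optimizer="adam", metrics=["acc"])
```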
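For reference, this is what the wrapped TensorFlow function computes per element. The pure-Python sketch below (function names are hypothetical, for illustration only) implements the numerically stable form described in the TensorFlow documentation, (1 - z) * x + l * (log(1 + exp(-|x|)) + max(-x, 0)) with l = 1 + (q - 1) * z, alongside the naive definition so the two can be compared.

```python
import math

def weighted_ce_with_logits(label, logit, pos_weight):
    """Stable per-element weighted sigmoid cross-entropy.

    z = label (0 or 1), x = logit, q = pos_weight.
    """
    # l = 1 + (q - 1) * z scales the log(1 + exp(-x)) term for positives
    l = 1.0 + (pos_weight - 1.0) * label
    # log(1 + exp(-x)) rewritten so exp() never overflows for large |x|
    softplus_neg = math.log1p(math.exp(-abs(logit))) + max(-logit, 0.0)
    return (1.0 - label) * logit + l * softplus_neg

def naive_weighted_ce(label, logit, pos_weight):
    # Direct definition: only the positive-label term is scaled by pos_weight
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
```

The two agree for moderate logits, but for something like label=0.0, logit=1000.0 the naive version raises ValueError (1 - p rounds to 0, and math.log(0.0) is undefined), while the stable form returns 1000.0.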
My question is: is there a way to use tf.nn.weighted_cross_entropy_with_logits with the tf.keras API directly?
You can pass the class weights directly to model.fit through its class_weight argument:
class_weight:
Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to "pay more attention" to samples from an under-represented class.
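For 0/1 labels, pos_weight=q and class_weight={0: 1, 1: q} weight each sample the same way: both multiply the loss of positive samples by q and leave negatives unchanged. A small pure-Python check, using hypothetical helper names and naive formulas that are only suitable for moderate logits:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def pos_weighted_bce(y, logit, pos_weight):
    # pos_weight scales only the term for positive labels (y == 1)
    p = sigmoid(logit)
    return -(pos_weight * y * math.log(p) + (1 - y) * math.log(1 - p))

def class_weighted_bce(y, logit, class_weight):
    # ordinary binary cross-entropy, scaled by the weight of the sample's class
    p = sigmoid(logit)
    bce = -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return class_weight[y] * bce

pos_weight = 10.0
class_weight = {0: 1.0, 1: pos_weight}
for y in (0, 1):
    for logit in (-2.0, 0.0, 1.5):
        a = pos_weighted_bce(y, logit, pos_weight)
        b = class_weighted_bce(y, logit, class_weight)
        assert math.isclose(a, b), (y, logit, a, b)
```

One practical difference: class_weight only applies during training, as the documentation says, while a custom loss passed to compile also changes the loss reported by model.evaluate.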
For example, with five classes:
{
0: 0.31,
1: 0.33,
2: 0.36,
3: 0.42,
4: 0.48
}
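To show what weighting the loss by class means for a multi-class (softmax) model, here is a minimal pure-Python sketch: each sample's cross-entropy is multiplied by the weight of its true class, and the results are averaged over the batch. The function names are hypothetical, and the plain mean is a simplifying assumption; Keras's exact reduction may differ.

```python
import math

def softmax_cross_entropy(logits, label):
    # -log(softmax(logits)[label]), computed with log-sum-exp for stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]

def class_weighted_loss(batch_logits, labels, class_weight):
    # scale each sample's loss by its true class's weight, then take the mean
    weighted = [class_weight[y] * softmax_cross_entropy(z, y)
                for z, y in zip(batch_logits, labels)]
    return sum(weighted) / len(weighted)

class_weight = {0: 0.31, 1: 0.33, 2: 0.36, 3: 0.42, 4: 0.48}
batch_logits = [[2.0, 0.5, 0.1, -1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
labels = [0, 4]
loss = class_weighted_loss(batch_logits, labels, class_weight)
```

With all-zero logits, every class gets probability 1/5, so a class-4 sample contributes 0.48 * log(5).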
Edit: JL Meunier's answer clearly explains how to multiply the logits by the class weights.
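The answer does not say how the example weights were chosen. One common heuristic, and an assumption here rather than necessarily what produced the values above, is the "balanced" inverse-frequency rule used by scikit-learn's compute_class_weight: weight_c = n_samples / (n_classes * count_c). A minimal sketch with a hypothetical helper:

```python
from collections import Counter

def balanced_class_weight(labels):
    """Map each class index to n_samples / (n_classes * count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)  # only classes that actually occur in labels
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

labels = [0, 0, 0, 0, 0, 0, 1, 1, 2]  # toy, imbalanced integer labels
class_weight = balanced_class_weight(labels)
```

Rarer classes get larger weights, and the resulting dict can be passed as model.fit(x, y, class_weight=class_weight).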