python keras metrics image-segmentation

How to get the IoU of a single class in Keras semantic segmentation?


I am using the image segmentation guide by fchollet to perform semantic segmentation. I adapted the guide to my dataset by relabelling the 8-bit image mask values to 1 and 2, as in the Oxford Pets dataset (these are shifted down to 0 and 1 in class OxfordPets(keras.utils.Sequence)).

The question is: how do I get the IoU metric of a single class (e.g. class 1)?

I have tried different metrics suggested on Stack Overflow, but most answers suggest MeanIoU, which gave me a nan loss when I tried it. Here is an example of a mask after applying autocontrast with PIL.ImageOps.autocontrast(load_img(val_target_img_paths[i])): [image: example mask after autocontrast]

The model seems to train well, but the accuracy decreased over time.

Also, can someone explain how the metric score is calculated from y_true and y_pred? I don't fully understand where the label value is used in the IoU calculation.


Solution

  • I had a similar problem a while back. I used jaccard_distance_loss and dice_metric, both of which are based on IoU. My task was binary segmentation, so you may have to modify the code if you want to use it for a multi-label classification problem.

    from keras import backend as K
    import numpy as np
    
    def jaccard_distance_loss(y_true, y_pred, smooth=100):
        """
        Jaccard = (|X & Y|)/ (|X|+ |Y| - |X & Y|)
                = sum(|A*B|)/(sum(|A|)+sum(|B|)-sum(|A*B|))
        
        The Jaccard distance loss is useful for unbalanced datasets. This has been
        shifted so it converges on 0 and is smoothed to avoid exploding or disappearing
        gradients.
        
        Ref: https://en.wikipedia.org/wiki/Jaccard_index
        
        @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
        @author: wassname
        """
        intersection = K.sum(K.sum(K.abs(y_true * y_pred), axis=-1))
        sum_ = K.sum(K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1))
        jac = (intersection + smooth) / (sum_ - intersection + smooth)
        return (1 - jac) * smooth
    
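    def dice_metric(y_pred, y_true):
        # Dice = 2*|X & Y| / (|X| + |Y|); symmetric, so argument order does not matter
        intersection = K.sum(K.sum(K.abs(y_true * y_pred), axis=-1))
        # Sum of both mask sizes (|X| + |Y|), not the set union
        total = K.sum(K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1))
        # Note: if both masks are empty, total is 0 and the result is nan
        return 2*intersection / total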
    
    # Example
    size = 10
    
    y_true = np.zeros(shape=(size,size))
    y_true[3:6,3:6] = 1
    
    y_pred = np.zeros(shape=(size,size))
    y_pred[3:5,3:5] = 1
    
    loss = jaccard_distance_loss(y_true,y_pred)
    
    metric = dice_metric(y_pred,y_true)
    
    print(f"loss: {loss}")
    print(f"dice_metric: {metric}")
    
    
    # Output:
    # loss: 4.587155963302747
    # dice_metric: 0.6153846153846154
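
For the original question (the IoU of one class, and where the label value enters the calculation), a minimal NumPy sketch may help. It assumes y_true and y_pred are integer label maps, one class id per pixel. The helper name class_iou is hypothetical and not part of Keras. The label value is used only to turn each map into a binary mask (pixel == class_id). IoU is then the overlap of the two masks divided by their union.

```python
import numpy as np

def class_iou(y_true, y_pred, class_id):
    """IoU (Jaccard index) of one class, computed from integer label maps.

    Hypothetical helper for illustration; not a Keras API.
    """
    # The label value is used here: each map becomes a boolean mask for class_id
    true_mask = (y_true == class_id)
    pred_mask = (y_pred == class_id)
    intersection = np.logical_and(true_mask, pred_mask).sum()
    union = np.logical_or(true_mask, pred_mask).sum()
    # Class absent from both maps: IoU is undefined, so report nan
    if union == 0:
        return float("nan")
    return intersection / union

# Same toy masks as in the example above, read as label maps
# (0 = background, 1 = foreground)
size = 10
y_true = np.zeros((size, size), dtype=np.int64)
y_true[3:6, 3:6] = 1   # 9 foreground pixels
y_pred = np.zeros((size, size), dtype=np.int64)
y_pred[3:5, 3:5] = 1   # 4 foreground pixels, all inside the true region

iou_fg = class_iou(y_true, y_pred, class_id=1)   # 4 / 9
iou_bg = class_iou(y_true, y_pred, class_id=0)   # 91 / 96

# A softmax output of shape (H, W, num_classes) has to be reduced to
# class ids with argmax before it can be compared with the labels
probs = np.stack([1.0 - y_pred, y_pred], axis=-1).astype(float)
labels = probs.argmax(axis=-1)

print(f"IoU class 1: {iou_fg:.4f}")
print(f"IoU class 0: {iou_bg:.4f}")
```

For use during training, recent TensorFlow releases also provide tf.keras.metrics.IoU, which takes num_classes and target_class_ids. Check the documentation of your installed version before relying on it, in particular how it expects y_pred (class ids or per-class scores).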