
How to compute top-k IoU in semantic segmentation using PyTorch?

I have the outputs of a neural network, outputs, and a ground-truth label, label, with outputs.shape = (N, C, H, W) and label.shape = (N, H, W), where N is the batch size, C is the number of classes, and H and W are the crop height and width. Each element of label is in the range [0, ..., C-1], that is, 0 <= label[i,j,k] <= C-1 for all i, j, k. I want to compute the top-3 IoU of outputs with respect to label, so I need a one-hot version of the top-3 output: a 1 in the ground-truth channel of each pixel whose label is among its top 3 predicted classes, and 0 everywhere else. For example:

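import torch

N, C, H, W = 1, 4, 2, 2
outputs = torch.rand((N, C, H, W))
label = torch.arange(C).reshape(N, H, W)  # works here because C == H * W
_, index = torch.topk(outputs, k=3, dim=1)  # [N, 3, H, W]: class indices of the 3 highest scores
top3 = torch.zeros((N, C, H, W))

# Set the ground-truth channel of a pixel to 1 if its label is among the top 3 classes
for i in range(N):
    for j in range(H):
        for k in range(W):
            c = label[i, j, k]
            if c in index[i, :, j, k]:
                top3[i, c, j, k] = 1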


outputs:
tensor([[[[0.8002, 0.6733],
          [0.7034, 0.5039]],

         [[0.8401, 0.9226],
          [0.7963, 0.6157]],

         [[0.1063, 0.0310],
          [0.2489, 0.9920]],

         [[0.8279, 0.9109],
          [0.4737, 0.2299]]]])


label:
tensor([[[0, 1],
         [2, 3]]])


index:
tensor([[[[1, 1],
          [1, 2]],

         [[3, 3],
          [0, 1]],

         [[0, 0],
          [3, 0]]]])


top3:
tensor([[[[1., 0.],
          [0., 0.]],

         [[0., 1.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]]])
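For reference, the triple loop can also be written without Python loops. The sketch below uses a hypothetical helper, topk_onehot (not part of the question): it computes a per-pixel flag saying whether the label is among the top k indices, then uses Tensor.scatter_ to write that flag into the ground-truth channel. With the tensors printed above it should reproduce the same top3.

```python
import torch

def topk_onehot(outputs: torch.Tensor, label: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper, a vectorized equivalent of the triple loop.

    Returns an [N, C, H, W] tensor with a 1 at [i, label[i, h, w], h, w] when
    label[i, h, w] is among the top k classes of outputs[i, :, h, w], else 0.
    """
    _, index = torch.topk(outputs, k=k, dim=1)        # [N, k, H, W]
    hit = (index == label.unsqueeze(1)).any(dim=1)    # [N, H, W] bool: label is in the top k
    top = torch.zeros_like(outputs)                   # [N, C, H, W]
    # At every pixel, write the hit flag (0. or 1.) into the label's own channel.
    top.scatter_(1, label.unsqueeze(1), hit.unsqueeze(1).to(outputs.dtype))
    return top

# The example tensors printed in the question.
outputs = torch.tensor([[[[0.8002, 0.6733], [0.7034, 0.5039]],
                         [[0.8401, 0.9226], [0.7963, 0.6157]],
                         [[0.1063, 0.0310], [0.2489, 0.9920]],
                         [[0.8279, 0.9109], [0.4737, 0.2299]]]])
label = torch.arange(4).reshape(1, 2, 2)
top3 = topk_onehot(outputs, label, k=3)
```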

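Then I can use top3 to compute the top-3 IoU. There is a vectorized version that works for top-3 pixel-wise accuracy, but it sets a 1 for every one of the top 3 classes at each pixel, whether or not that class is the ground truth, so it creates many false 1s and cannot be used to compute IoU: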

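# Pass num_classes explicitly; otherwise it is inferred as index.max() + 1, which can be < C
expand = torch.nn.functional.one_hot(index, num_classes=C)  # [N, 3, H, W, C]
top3 = expand.transpose(1, 4).sum(dim=4)  # [N, C, H, W]: 1 for every class in the top 3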


top3:
tensor([[[[0, 1],
          [0, 0]],

         [[1, 0],
          [1, 1]],

         [[1, 1],
          [1, 1]],

         [[1, 1],
          [1, 1]]]])
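The false 1s come from keeping every top-3 class instead of only the ground-truth one. A minimal sketch of one way to fix this vectorized version, assuming the goal is the same tensor the loop produces: multiply the summed one-hot top-k indices by a one-hot encoding of label, so only the ground-truth channel can survive. The helper name topk_onehot_masked is hypothetical.

```python
import torch
import torch.nn.functional as F

def topk_onehot_masked(outputs: torch.Tensor, label: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: [N, C, H, W] int64 tensor with a 1 only in the
    ground-truth channel of pixels whose label is among the top k classes."""
    C = outputs.shape[1]
    _, index = torch.topk(outputs, k=k, dim=1)               # [N, k, H, W]
    in_topk = F.one_hot(index, num_classes=C).sum(dim=1)     # [N, H, W, C]: 1 for each top-k class
    is_label = F.one_hot(label, num_classes=C)               # [N, H, W, C]: 1 for the true class
    return (in_topk * is_label).permute(0, 3, 1, 2)          # keep only the true class -> [N, C, H, W]

# The example tensors printed in the question.
outputs = torch.tensor([[[[0.8002, 0.6733], [0.7034, 0.5039]],
                         [[0.8401, 0.9226], [0.7963, 0.6157]],
                         [[0.1063, 0.0310], [0.2489, 0.9920]],
                         [[0.8279, 0.9109], [0.4737, 0.2299]]]])
label = torch.arange(4).reshape(1, 2, 2)
top3 = topk_onehot_masked(outputs, label, k=3)
```

Because torch.topk returns k distinct indices per pixel, the sum over the k one-hot maps is already 0 or 1, so the product needs no clamping.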


  • I think an efficient (vectorized) solution would look something like this:

    1. Compute top k index maps [N,k,H,W]
    2. Check each of the k index maps against the ground truth labels [N,k,H,W]
    3. Sum across the k maps

    followed by some arithmetic to get the specific value you care about.

    import torch

    # 1. Compute top k index maps
    k = 3  # works for any k <= C
    topk = torch.topk(outputs, k=k, dim=1).indices  # [N,C,H,W] in, [N,k,H,W] out
    # 2. Check for label equality against each index map
    labels_expanded = label.unsqueeze(1).expand(N, k, H, W)  # tensor view of size [N,k,H,W]
    tp = (labels_expanded == topk).long()  # 1 where equal, 0 otherwise
    # tp is [N,k,H,W]: a 1 at [i,j,m,n] means label[i,m,n] == topk[i,j,m,n] (a true positive)
    # 3. Sum across all k maps
    tp_sum = tp.sum(dim=1)  # now of shape [N,H,W]
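Because torch.topk returns k distinct class indices at each pixel, at most one of the k maps can match the label, so tp_sum is a 0/1 hit map and its mean is the top-k pixel accuracy. A short sketch using the question's printed example tensors (the name topk_accuracy is my own); here broadcasting stands in for the explicit expand:

```python
import torch

# The example tensors printed in the question (N=1, C=4, H=W=2).
outputs = torch.tensor([[[[0.8002, 0.6733], [0.7034, 0.5039]],
                         [[0.8401, 0.9226], [0.7963, 0.6157]],
                         [[0.1063, 0.0310], [0.2489, 0.9920]],
                         [[0.8279, 0.9109], [0.4737, 0.2299]]]])
label = torch.arange(4).reshape(1, 2, 2)

k = 3
topk = torch.topk(outputs, k=k, dim=1).indices      # [N, k, H, W]
tp = (label.unsqueeze(1) == topk).long()            # broadcasts [N, 1, H, W] against [N, k, H, W]
tp_sum = tp.sum(dim=1)                              # [N, H, W], every entry is 0 or 1
topk_accuracy = tp_sum.float().mean().item()        # fraction of pixels whose label is in the top k
```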

    All that remains is a bit of elementary IoU math (IoU is TP / (TP + FP + FN) for pixel-wise tasks), but I'll leave the specifics up to you, as you may want to exclude a background class; otherwise IoU reduces simply to accuracy (TP / total). Personally, I find the notion of IoU for dense semantic segmentation tasks, where every pixel has a label, semantically misleading, but that's outside the scope of this question.
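One way to do that IoU math, sketched under an assumption the question does not fix: take the effective prediction at each pixel to be the ground-truth class if it appears in the top k, and the top-1 class otherwise, then compute per-class IoU = TP / (TP + FP + FN) from a confusion matrix built with torch.bincount. The function topk_iou and its exclude_class parameter (for dropping a background class) are hypothetical, and other top-k conventions would give different numbers.

```python
import torch

def topk_iou(outputs: torch.Tensor, label: torch.Tensor, k: int, exclude_class=None):
    """Hypothetical helper. Per-class top-k IoU, assuming the effective prediction at a
    pixel is the ground-truth class when it is among the top k, else the top-1 class."""
    C = outputs.shape[1]
    index = torch.topk(outputs, k=k, dim=1).indices       # [N, k, H, W], sorted, so index[:, 0] is top-1
    hit = (index == label.unsqueeze(1)).any(dim=1)        # [N, H, W]: label is in the top k
    pred = torch.where(hit, label, index[:, 0])           # effective prediction per pixel
    # Confusion matrix: rows = ground truth, columns = effective prediction.
    conf = torch.bincount(label.flatten() * C + pred.flatten(), minlength=C * C).reshape(C, C)
    tp = conf.diag().float()
    fp = conf.sum(dim=0).float() - tp                     # predicted c, labelled otherwise
    fn = conf.sum(dim=1).float() - tp                     # labelled c, predicted otherwise
    iou = tp / (tp + fp + fn)                             # 0/0 -> nan for classes absent from both
    keep = torch.ones(C, dtype=torch.bool)
    if exclude_class is not None:
        keep[exclude_class] = False                       # e.g. drop a background class
    return iou, torch.nanmean(iou[keep])

# The example tensors printed in the question.
outputs = torch.tensor([[[[0.8002, 0.6733], [0.7034, 0.5039]],
                         [[0.8401, 0.9226], [0.7963, 0.6157]],
                         [[0.1063, 0.0310], [0.2489, 0.9920]],
                         [[0.8279, 0.9109], [0.4737, 0.2299]]]])
label = torch.arange(4).reshape(1, 2, 2)
per_class, miou = topk_iou(outputs, label, k=3)
_, miou_no_bg = topk_iou(outputs, label, k=3, exclude_class=0)
```

For the question's example, the pixels labelled 2 and 3 miss the top 3, so those classes get IoU 0; falling back to the top-1 class at those pixels also adds false positives to classes 1 and 2.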