Tags: label, mask, image-segmentation

How to change the order of class labels in a segmentation mask


I wonder whether it's possible to change the class labels in a segmentation mask. Suppose I have masks where the classes are ordered like this: 0 - class A, 1 - class B, 2 - background. Is there a way to change the order to: 0 - background, 1 - class A, 2 - class B?

Here is my code. Currently my mask is ordered like this: 0 - background, 1 - ground glass, 2 - consolidation, 3 - lung other. What I want: 0 - ground glass, 1 - consolidation, 2 - lung other, 3 - background. The mask is one-hot encoded, with one channel per class along the last axis.
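The two orderings above can be turned into index tables mechanically instead of by hand. Below is a minimal sketch in plain Python; the names `old_order`, `new_order`, `value_lut` and `channel_perm` are illustrative and not part of the original code. It assumes the class names are unique.

```python
# Class order currently stored in the mask (index = label value / channel).
old_order = ["background", "ground glass", "consolidation", "lung other"]
# Desired class order.
new_order = ["ground glass", "consolidation", "lung other", "background"]

# Value lookup table for integer label maps: value_lut[old_label] == new_label.
value_lut = [new_order.index(name) for name in old_order]

# Channel permutation for one-hot masks: channel_perm[new_channel] == old_channel.
channel_perm = [old_order.index(name) for name in new_order]

print(value_lut)     # [3, 0, 1, 2]
print(channel_perm)  # [1, 2, 3, 0]
```

`value_lut` is meant for masks that store one class index per pixel, while `channel_perm` is meant for one-hot masks like `mask_four`. The two lists are inverse permutations of each other, so mixing them up silently produces a wrong (but valid-looking) order.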

Before changing:

import matplotlib.pyplot as plt

# mask_four is indexed as [sample, row, col, class_channel] (one-hot masks)
i = 31
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes[0, 0].imshow(mask_four[i, :, :, 0])
axes[0, 0].set_title('Background')
axes[0, 1].imshow(mask_four[i, :, :, 1])
axes[0, 1].set_title('Ground glass')
axes[1, 0].imshow(mask_four[i, :, :, 2])
axes[1, 0].set_title('Consolidation')
axes[1, 1].imshow(mask_four[i, :, :, 3])
axes[1, 1].set_title('Lung other')
plt.show()

Output: [image: channels 0-3 of mask_four[31], titled Background, Ground glass, Consolidation, Lung other]

Changing:

import torch

# Apply the lookup table to the mask values
lut = torch.tensor([1, 2, 3, 0])
test_mask = mask_four[31]
test_mask2 = lut[test_mask.long()]

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes[0, 0].imshow(test_mask2[:, :, 0])
axes[0, 1].imshow(test_mask2[:, :, 1])
axes[1, 0].imshow(test_mask2[:, :, 2])
axes[1, 1].imshow(test_mask2[:, :, 3])

Output 2: [image: channels 0-3 of test_mask2]

This is test_mask:

tensor([[[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],

        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],

        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],

        ...,

This is test_mask2:

tensor([[[2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         ...,
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1]],

        [[2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         ...,
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1]],

        [[2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         ...,
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1]],

        ...,
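The pattern in test_mask2 can be reproduced on a tiny hypothetical one-hot mask (built with `torch.eye`, one pixel per class; it stands in for the real data). Indexing `lut` with the mask looks up each *value*, and a one-hot mask only contains 0s and 1s, so every entry becomes either `lut[0]` or `lut[1]`:

```python
import torch

lut = torch.tensor([1, 2, 3, 0])

# Hypothetical mask of shape (H=1, W=4, C=4): pixel k belongs to class k,
# in the original order background, ground glass, consolidation, lung other.
mask = torch.eye(4).unsqueeze(0)  # float, like the question's mask

# Value lookup: 0 -> lut[0] == 1 and 1 -> lut[1] == 2; channels are not moved.
out = lut[mask.long()]
print(out)
# tensor([[[2, 1, 1, 1],
#          [1, 2, 1, 1],
#          [1, 1, 2, 1],
#          [1, 1, 1, 2]]])
```

The channels stay where they were and only their values change, which matches the `[2, 1, 1, 1]` rows above for background pixels.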

Solution

  • I think what you are looking for is applying a lookup table (LUT) to the integer target labels, i.e. a label map that stores one class index per pixel.
    This can be done fairly easily:

    In []: orig_target = torch.randint(0, 3, (3,4)); orig_target
    Out[]:
    tensor([[2, 1, 2, 1],
            [2, 2, 1, 2],
            [1, 2, 0, 0]])
    
    # define the lookup table: 0 -> 2, 1 -> 0, 2 -> 1:
    In []: lut = torch.tensor([2, 0, 1])
    
    # apply the lookup table
    In []: lut[orig_target]
    Out[]:
    tensor([[1, 0, 1, 0],
            [1, 1, 0, 1],
            [0, 1, 2, 2]])
    
    

    Update:

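    If you want to permute the channels of a one-hot mask (one channel per class along the last axis, like `test_mask` in the question), index the channel axis with the permutation instead. Note the direction: `lut[j]` is the old channel that becomes new channel `j`, so the question's own `lut = torch.tensor([1, 2, 3, 0])` already moves background to the last channel:

    per_target = test_mask[..., lut]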
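A slightly fuller sketch for the question's layout, assuming a channels-last one-hot mask of shape (H, W, 4) like `test_mask`. The small `labels` tensor is made up for illustration and stands in for real data. The sketch permutes the channels, then checks the result against the label-map route (argmax followed by a value LUT), using `torch.argsort` to invert the permutation:

```python
import torch
import torch.nn.functional as F

# Hypothetical 2x2 label map in the original order:
# 0 background, 1 ground glass, 2 consolidation, 3 lung other.
labels = torch.tensor([[0, 1],
                       [2, 3]])
mask = F.one_hot(labels, num_classes=4).float()  # shape (2, 2, 4), channels last

# Channel permutation: new channel j takes old channel perm[j].
perm = torch.tensor([1, 2, 3, 0])
permuted = mask[..., perm]  # 0 ground glass, 1 consolidation, 2 lung other, 3 background

# The value LUT for integer label maps is the inverse permutation.
value_lut = torch.argsort(perm)  # tensor([3, 0, 1, 2])
assert torch.equal(permuted.argmax(dim=-1), value_lut[mask.argmax(dim=-1)])

print(permuted.argmax(dim=-1))
# tensor([[3, 0],
#         [1, 2]])
```

For a channels-first batch of shape (N, C, H, W), the same permutation would be applied on dim 1, e.g. `mask[:, perm]`.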