I wonder if it's possible to change the class labels in a segmentation mask. Suppose I have masks where the classes are ordered like this: 0 - class A, 1 - class B, 2 - background. Is there a way to change the order to: 0 - background, 1 - class A, 2 - class B?
Here is my code. Currently my mask is ordered as: 0 - background, 1 - ground glass, 2 - consolidation, 3 - lung other. What I want: 0 - ground glass, 1 - consolidation, 2 - lung other, 3 - background.
Before changing:
import matplotlib.pyplot as plt

# mask_four: one-hot masks with the class channels last, shape (N, H, W, 4)
i = 31  # index of the sample to inspect
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes[0, 0].imshow(mask_four[i, :, :, 0])
axes[0, 0].set_title('Background')
axes[0, 1].imshow(mask_four[i, :, :, 1])
axes[0, 1].set_title('Ground glass')
axes[1, 0].imshow(mask_four[i, :, :, 2])
axes[1, 0].set_title('Consolidation')
axes[1, 1].imshow(mask_four[i, :, :, 3])
axes[1, 1].set_title('Lung other')
plt.show()
Changing:
import torch

# lut maps each VALUE v stored in the mask to lut[v]
lut = torch.tensor([1, 2, 3, 0])
test_mask = mask_four[31]
test_mask2 = lut[test_mask.long()]  # index tensors must be integer, hence .long()
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
axes[0, 0].imshow(test_mask2[:, :, 0])
axes[0, 1].imshow(test_mask2[:, :, 1])
axes[1, 0].imshow(test_mask2[:, :, 2])
axes[1, 1].imshow(test_mask2[:, :, 3])
plt.show()
This is test_mask:
tensor([[[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],

        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],

        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         ...,
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.]],

        ...,
This is test_mask2:
tensor([[[2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         ...,
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1]],

        [[2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         ...,
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1]],

        [[2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         ...,
         [2, 1, 1, 1],
         [2, 1, 1, 1],
         [2, 1, 1, 1]],

        ...,
If your mask stores class indices (one integer label per pixel), what you are looking for is a lookup table applied to the target labels.
This can be done with plain tensor indexing:
In []: orig_target = torch.randint(0, 3, (3,4)); orig_target
Out[]:
tensor([[2, 1, 2, 1],
        [2, 2, 1, 2],
        [1, 2, 0, 0]])
# define the lookup table: 0 -> 2, 1 -> 0, 2 -> 1:
In []: lut = torch.tensor([2, 0, 1])
# apply the lookup table
In []: lut[orig_target]
Out[]:
tensor([[1, 0, 1, 0],
        [1, 1, 0, 1],
        [0, 1, 2, 2]])
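For the four classes in the question, the table can be derived from the old and new class orders instead of being written by hand. This is a minimal sketch assuming the mask holds one integer class index per pixel (a one-hot mask such as mask_four could first be reduced with argmax(dim=-1)); the class-name lists and the small label_mask are illustrative, not taken from the original data. Index tensors must be integer, so a float mask needs .long() first.

```python
import torch

# Illustrative class-name lists describing the old and new label orders.
old_classes = ["background", "ground glass", "consolidation", "lung other"]
new_classes = ["ground glass", "consolidation", "lung other", "background"]

# lut[old_label] = new_label: look each old class up in the new order.
lut = torch.tensor([new_classes.index(name) for name in old_classes])
print(lut.tolist())  # [3, 0, 1, 2]

# A tiny illustrative label-index mask (H x W) using the OLD order.
label_mask = torch.tensor([[0, 1, 2],
                           [3, 0, 1]])

# Indexing the lookup table with the mask remaps every pixel at once.
new_mask = lut[label_mask]
print(new_mask.tolist())  # [[3, 0, 1], [2, 3, 0]]
```

Here background (old label 0) becomes 3 and ground glass (old label 1) becomes 0, matching the requested order.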
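Update:
If your mask is one-hot encoded with the classes along the last dimension (as mask_four is; that is why the value lookup turned every [1, 0, 0, 0] into [2, 1, 1, 1]), permute the channels instead of remapping the values. Note that the index list now works in the opposite direction: entry k is the old channel that becomes new channel k. For your ordering that is [1, 2, 3, 0]:
perm = torch.tensor([1, 2, 3, 0])  # new channel k <- old channel perm[k]
per_target = mask_four[..., perm]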
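To check that permuting the channels of a one-hot mask gives the same result as remapping label values, and that the channel index list is the inverse of the value lookup table, here is a minimal sketch on a tiny synthetic mask. It assumes channels last, as in mask_four; the 2x3 label values are made up for illustration, while torch.nn.functional.one_hot and torch.argsort are standard PyTorch functions.

```python
import torch
import torch.nn.functional as F

# Tiny synthetic labels in the OLD order:
# 0=background, 1=ground glass, 2=consolidation, 3=lung other.
labels_old = torch.tensor([[0, 1, 2],
                           [3, 0, 1]])
# One-hot with channels last, shape (H, W, 4), float like mask_four.
onehot_old = F.one_hot(labels_old, num_classes=4).float()

# Channel permutation: new channel k is taken from old channel perm[k].
perm = torch.tensor([1, 2, 3, 0])
onehot_new = onehot_old[..., perm]

# The value lookup table (lut[old_label] = new_label) is the inverse
# permutation of perm, which argsort computes for a permutation.
lut = torch.argsort(perm)
print(lut.tolist())  # [3, 0, 1, 2]

# Both routes agree: permuting channels == one-hot of remapped labels.
expected = F.one_hot(lut[labels_old], num_classes=4).float()
print(torch.equal(onehot_new, expected))  # True
print(onehot_new.argmax(dim=-1).tolist())  # [[3, 0, 1], [2, 3, 0]]
```

So the same four numbers mean different things depending on the mask format: [1, 2, 3, 0] is the right channel index for a one-hot mask, while [3, 0, 1, 2] is the right lookup table for a label-index mask.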