
Convert a 9-channel torch tensor to a 3-channel (or 1-channel) image to display it


I have a tensor with 9 channels, of shape [9, 224, 224], which is the result of a prediction. How can I convert it to a 3-channel image so that I can display it?

predicted = predicted.cpu()
label = predicted[0]
print(label.shape)

torch.Size([9, 224, 224])

Solution

  • I'm assuming that your (9, 224, 224) data are semantic segmentation maps. There are two possible cases:

    1. You have multi-class predictions (each pixel belongs to exactly one of the 9 classes)
    import cv2
    import numpy as np

    # `prediction` is the (9, 224, 224) torch tensor of logits
    # find normalized probabilities that sum to 1 across the classes
    prediction = prediction.softmax(dim=0).cpu().numpy()

    # find the most probable class for each pixel
    labels = prediction.argmax(axis=0)

    # create a random color palette that maps class_idx to a 3-channel color
    palette = np.random.randint(0, 255, (prediction.shape[0], 3), np.uint8)
    color_mask = np.zeros((*labels.shape, 3), np.uint8)
    # map each label to its color
    for idx, color in enumerate(palette):
        color_mask[labels == idx] = color

    cv2.imshow('color_mask', color_mask)
    cv2.waitKey()
    

    Example of visualization: [figure: two example images]
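The per-class loop above can also be written as a single palette lookup. The following is a minimal sketch using NumPy only; the random `logits` array is a hypothetical stand-in for a real model output (a torch tensor would first be converted with `.cpu().numpy()`), and the `gray_mask` scaling is just one assumed way to get a 1-channel image.

```python
import numpy as np

# Hypothetical stand-in for a real model output: random logits for 9 classes.
rng = np.random.default_rng(0)
logits = rng.normal(size=(9, 224, 224)).astype(np.float32)

# Most probable class per pixel. Softmax is monotonic, so taking the
# argmax over the raw logits yields the same labels.
labels = logits.argmax(axis=0)  # shape (224, 224), values 0..8

# One random color per class, shape (9, 3).
palette = rng.integers(0, 256, size=(logits.shape[0], 3), dtype=np.uint8)

# Integer-array indexing replaces every label with its palette row at once.
color_mask = palette[labels]  # shape (224, 224, 3), dtype uint8

# Single-channel alternative: spread the class indices over the 0..255 range.
step = 255 // (logits.shape[0] - 1)
gray_mask = (labels * step).astype(np.uint8)  # shape (224, 224)
```

Both `color_mask` and `gray_mask` are ordinary `uint8` arrays, so they should be displayable with the same `cv2.imshow` call used above.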

    2. You have multi-label predictions. In that case you have 9 independent prediction masks
    # prediction = torch.sigmoid(prediction)  # in the case of logits
    # move the tensor to the CPU and convert it to a NumPy array
    prediction = prediction.cpu().numpy()
    # convert 0-1 probability maps into 0-255
    prediction = (prediction * 255).astype(np.uint8)
    # stack the 9 probability maps horizontally into one (224, 9 * 224) grayscale image
    prediction = np.hstack(prediction)
    

    Example: [figure: example image]

    Image taken from https://www.publicdomainpictures.net/en/view-image.php?image=24076
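For the multi-label case, a long horizontal strip of 9 maps can be hard to read, so a square grid may be more convenient. The following is a minimal sketch, not part of the original answer: `probs` is a hypothetical stand-in for real sigmoid outputs, and the 0.5 threshold and the 3 x 3 layout are assumptions.

```python
import numpy as np

# Hypothetical stand-in: 9 independent probability maps with values in [0, 1).
rng = np.random.default_rng(0)
probs = rng.random((9, 224, 224), dtype=np.float32)

# Binarize each map at an assumed threshold of 0.5 and scale to 0 / 255.
binary = ((probs > 0.5) * 255).astype(np.uint8)  # shape (9, 224, 224)

# Arrange the 9 masks as a 3 x 3 grid in a single grayscale image.
rows, cols = 3, 3
_, h, w = binary.shape
grid = (
    binary.reshape(rows, cols, h, w)  # (rows, cols, h, w)
    .transpose(0, 2, 1, 3)            # (rows, h, cols, w)
    .reshape(rows * h, cols * w)      # (rows * h, cols * w)
)
```

Here mask `k` ends up in grid row `k // 3` and grid column `k % 3`. Skipping the thresholding step and scaling `probs` by 255 directly would keep the soft probabilities visible instead.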