Tags: python, deep-learning, pytorch, semantic-segmentation, unet-neural-network

Converting UNet model outputs from logits to segmentation masks


I've recently been working on a segmentation task on certain CT scans. I decided to use Python as the language, the residual UNet architecture implemented in MONAI, and DiceLoss for evaluation. Everything went smoothly until inference with the trained model. This residual UNet does not have a softmax or sigmoid layer, so it outputs raw logits, floats ranging from roughly -30 to 12. How can I properly convert those logits into probabilities of each pixel belonging to each class? The input, in BCHWD notation, is 1x1x256x256x256 (the image itself), while the output is 1x9x256x256x256, with each channel being a mask for a different class. The code for inference looks more or less like this:

import torch
from monai.networks.nets import Unet
from monai.networks.layers import Norm
from monai.inferers import SimpleInferer

# pick the best available device
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

model = Unet(
    spatial_dims = 3,
    in_channels = 1,
    out_channels = 9,
    channels = (8, 16, 32, 64, 128, 256),
    strides = (2,2,2,2,2),
    num_res_units = 4,
    norm = Norm.INSTANCE
).to(device)

model.load_state_dict(torch.load(PATH_TO_SAVED_MODEL_OBJECT)["model_state_dict"])
model.eval()   # switch to evaluation mode before inference

inputs, labels = next(iter(validation_dataloader))   # obtaining only one image - batch_size=1
inferer = SimpleInferer()   # simple inferer from monai.inferers
inputs = inputs.to(device)  # pass to GPU
labels = labels.to(device)  # pass to GPU
pred = inferer(inputs = inputs, network=model)
pred = pred.detach().cpu().numpy() # conversion to numpy array for viewing purposes

Thanks in advance for your assistance.


Solution

  • If each pixel can belong to multiple classes at once (multi-label segmentation), apply a sigmoid, i.e. prob = torch.sigmoid(pred), which gives the probability of each pixel belonging to each class independently. If each pixel belongs to exactly one class (mutually exclusive classes), apply a softmax over the channel dimension, i.e. prob = torch.softmax(pred, dim=1), which gives a proper probability distribution over the 9 classes for every pixel. Either way, the raw logits are mapped to probabilities in [0, 1]. A minimal sketch of the full conversion is shown below.
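
A minimal sketch of the conversion step, assuming pred is still a torch tensor of shape 1x9x256x256x256 (i.e., before the .numpy() call in your snippet); the 0.5 threshold in the multi-label branch is just an illustrative choice:

prob_to_mask.py

import torch

# pred: raw logits of shape (B, C, H, W, D) = (1, 9, 256, 256, 256)

# Mutually exclusive classes: softmax over the channel dimension,
# then argmax to get one class label per voxel.
prob = torch.softmax(pred, dim=1)        # probabilities, sum to 1 over dim=1
mask = torch.argmax(prob, dim=1)         # shape (1, 256, 256, 256), values in 0..8

# Multi-label case: independent sigmoid per channel,
# then threshold each channel.
prob_ml = torch.sigmoid(pred)            # each value in [0, 1], channels independent
mask_ml = (prob_ml > 0.5).long()         # shape (1, 9, 256, 256, 256), binary per class

# convert to numpy only after the conversion, for viewing purposes
mask_np = mask.detach().cpu().numpy()

Note that argmax over the softmax output gives the same labels as argmax over the raw logits, so the softmax is only needed if you actually want the probabilities themselves. Also do this on the torch tensor before converting to numpy, or use the numpy equivalents (e.g. np.argmax along axis 1) on the array you already have.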