Tags: machine-learning, deep-learning, softmax, activation-function, vnet

PyTorch V-Net final softmax activation layer for segmentation: the output has a different number of channels than the labels. How do I get the prediction output?


I am trying to build a V-Net. When I pass the images to segment during training, the output has 2 channels after the softmax activation (as specified in the architecture in the attached image), but the labels and inputs have 1. How do I convert this so that the output is the segmented image? Do I just take one of the channels as the final output when training (e.g. output = output[:, 0, :, :, :]), with the other channel being the background?

batch_size = 32
outputs = network(inputs)

# outputs.shape: [32, 2, 64, 128, 128]
# inputs.shape:  [32, 1, 64, 128, 128]
# labels.shape:  [32, 1, 64, 128, 128]
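Rather than keeping one channel, the usual way to turn an `[N, 2, D, H, W]` probability tensor into a segmentation with the same `[N, 1, D, H, W]` shape as the labels is to take the argmax over the channel dimension. The sketch below assumes labels are encoded as 0 = background and 1 = foreground, so channel 1 holds the foreground probability; the helper name `probs_to_mask` is hypothetical, and the random tensor is a stand-in for a real network output with the volume shrunk from the question's sizes.

```python
import torch
import torch.nn.functional as F

def probs_to_mask(probs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert per-voxel class probabilities [N, C, D, H, W]
    into a class-index map [N, 1, D, H, W] (same shape as the labels)."""
    # argmax over the channel dimension picks the most likely class per voxel;
    # keepdim=True keeps a singleton channel axis so the shape matches the labels.
    return probs.argmax(dim=1, keepdim=True)

# Stand-in for a network output: batch of 2, 2 classes, 4x8x8 volume.
torch.manual_seed(0)
logits = torch.randn(2, 2, 4, 8, 8)
probs = F.softmax(logits, dim=1)  # softmax over the class channel

mask = probs_to_mask(probs)
print(mask.shape)  # torch.Size([2, 1, 4, 8, 8])

# With two classes, argmax matches thresholding the foreground
# probability (channel 1) at 0.5, apart from floating-point near-ties.
fg = (probs[:, 1:2] > 0.5).long()
print(torch.equal(mask, fg))
```

Which channel means "foreground" is not fixed by the architecture: it is whatever class index the loss ties it to, so with 0/1 labels the second channel (index 1) becomes the foreground score.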

Here is my V-Net forward pass:

import torch.nn.functional as F

def forward(self, x):
    # Initial input transition
    out = self.in_tr(x)

    # Downward transitions; each also returns features kept for a skip connection
    out, residual_0 = self.down_depth0(out)
    out, residual_1 = self.down_depth1(out)
    out, residual_2 = self.down_depth2(out)
    out, residual_3 = self.down_depth3(out)

    # Bottom layer
    out = self.up_depth4(out)

    # Upward transitions, each using the matching skip connection
    out = self.up_depth3(out, residual_3)
    out = self.up_depth2(out, residual_2)
    out = self.up_depth1(out, residual_1)
    out = self.up_depth0(out, residual_0)

    # Convert to 2 channels (one score per class)
    out = self.final_conv(out)

    # Softmax over the channel dimension (dim=1) gives per-voxel class probabilities
    out = F.softmax(out, dim=1)

    return out  # shape: [batch_size, 2, 64, 128, 128]
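To train the 2-channel output against the 1-channel labels, one common approach (an assumption here, not something the question or paper prescribes) is to treat each label voxel as a class index: drop the label's channel axis, cast it to `long`, and use `nn.CrossEntropyLoss`, which accepts `[N, C, D, H, W]` inputs with `[N, D, H, W]` targets. `nn.CrossEntropyLoss` applies `log_softmax` itself, so it expects the raw output of `final_conv`; if `forward` keeps returning softmax probabilities, `nn.NLLLoss` on their log is the matching loss. The tensors below are random stand-ins with a shrunken volume.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, D, H, W = 2, 4, 8, 8

# Stand-in for the output of final_conv (raw logits, before softmax).
logits = torch.randn(N, 2, D, H, W, requires_grad=True)
# Binary labels shaped like the question's [N, 1, D, H, W]: 0 = background, 1 = foreground.
labels = torch.randint(0, 2, (N, 1, D, H, W)).float()

# Class-index targets: remove the channel axis and use an integer dtype.
target = labels.squeeze(1).long()  # [N, D, H, W]

# Option 1: CrossEntropyLoss on raw logits (log_softmax is applied internally).
ce_loss = nn.CrossEntropyLoss()(logits, target)

# Option 2: the network already returns softmax probabilities (as in forward above),
# so take their log and use NLLLoss.
probs = F.softmax(logits, dim=1)
nll_loss = nn.NLLLoss()(torch.log(probs), target)

ce_loss.backward()
print(ce_loss.item(), nll_loss.item())  # the two losses match up to rounding
```

Returning logits from `forward` and applying softmax plus argmax only at inference avoids taking the log of probabilities that can underflow to zero, which is why option 1 is usually preferred.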

V-Net architecture as described in the paper (https://arxiv.org/pdf/1606.04797.pdf)


Solution

  • The network in that paper has two output channels because it predicts two classes:

    The network predictions, which consist of two volumes having the same resolution as the original input data, are processed through a soft-max layer which outputs the probability of each voxel to belong to foreground and to background.

    Therefore, this is not an autoencoder, where the model is trained to reproduce its inputs as outputs. The authors use labels that distinguish their voxels of interest (foreground) from everything else (background). If your labels are not already foreground/background masks of this kind, you will need to change your data to use the V-Net in this manner.

    It won't be as simple as designating one channel as the output, because this is a per-voxel classification task rather than a regression task. You will need annotated labels (a class for every voxel) to train this architecture.
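For reference, the V-Net paper trains on such annotated masks with a Dice-based objective computed from the foreground probabilities, D = 2 Σ p_i g_i / (Σ p_i² + Σ g_i²), rather than plain cross-entropy. The sketch below assumes channel 1 of the softmax output is the foreground probability and the label is a 0/1 mask of shape `[N, 1, D, H, W]`; the function name `soft_dice_loss`, the `eps` smoothing term, and averaging over the batch are illustrative choices, not details taken from the paper.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(probs: torch.Tensor, labels: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: 1 - soft Dice between foreground probabilities and a binary mask.

    probs:  [N, 2, D, H, W] softmax output (channel 1 assumed to be foreground)
    labels: [N, 1, D, H, W] annotated mask with values in {0, 1}
    """
    p = probs[:, 1].flatten(start_dim=1)           # [N, D*H*W] foreground probabilities
    g = labels[:, 0].float().flatten(start_dim=1)  # [N, D*H*W] ground-truth mask
    intersection = (p * g).sum(dim=1)
    denom = (p * p).sum(dim=1) + (g * g).sum(dim=1)
    dice = (2 * intersection + eps) / (denom + eps)  # eps avoids 0/0 on empty masks
    return 1 - dice.mean()                           # averaged over the batch

torch.manual_seed(0)
labels = torch.randint(0, 2, (2, 1, 4, 8, 8))

# A "perfect" prediction: probability 1 on the labelled class, 0 on the other.
perfect = F.one_hot(labels[:, 0], num_classes=2).permute(0, 4, 1, 2, 3).float()
print(perfect.shape)                           # torch.Size([2, 2, 4, 8, 8])
print(soft_dice_loss(perfect, labels).item())  # 0.0

# An uninformed prediction gives a loss strictly between 0 and 1.
random_probs = F.softmax(torch.randn(2, 2, 4, 8, 8), dim=1)
print(soft_dice_loss(random_probs, labels).item())
```

Because this loss works on probabilities directly, it fits a forward pass that ends in a softmax over `dim=1`.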