Tags: c++, computer-vision, caffe, image-segmentation, markov-random-fields

Probability as input to Markov random field (MRF): how to refine the cmex code?


I am very new to MRFs and not very good at programming. I have obtained a probability map from semantic segmentation with a CNN, and I have to optimize the segmentation using Markov random fields (MRF). I downloaded the code provided by Shai Bagon at this link: GCmex. Energy minimization is performed with either alpha-expansion or alpha-beta swap.

I compiled the code with mex, and I need to modify the unary and pairwise energy functions. I have a stack of images, and I need to extract the 6-neighborhood grid and include the resulting neighborhood in the pairwise term.

The input to the unary term is the probability map, a stack of size (256,256,4) for 4 different classes: [image: probability map]

My questions are:

1) I want to change the unary and pairwise formulations. Has anyone modified the code to define different energy functions? Which functions, and which parts of the code, should be modified and recompiled?

2) How do I change w_i,j? It is calculated from the intensity difference, but here I have only probabilities. Is it the difference between the probabilities of two adjacent voxels?

I would really appreciate your help. Thanks.


Solution

You have 60 slices of 256x256 pixels (~4M voxels in total), that is, slices is a 256-by-256-by-60 array. Once you feed slices into your net (one by one or in batches, whatever works best for you), you have prob, a probability array of size 256-by-256-by-60-by-4.
I suggest you use the third constructor of GCMex to construct your graph for optimization.
To do so, you first need to define a sparse graph. Use sparse_adj_matrix:

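    slices = double(slices);  % avoid saturating integer arithmetic in the differences below
    [ii, jj] = sparse_adj_matrix([256 256 60], 1, 1);  % 6-connected 3D grid
    n = prod([256 256 60]);  % number of voxels
    sig2 = mean((slices(ii) - slices(jj)).^2);  % assumption: one common choice (as in GrabCut); tune for your data
    wij = exp(-((slices(ii) - slices(jj)).^2) / (2*sig2));  % w_ij = exp(-|I_i - I_j|^2 / (2*sig^2))
    W = sparse(ii, jj, wij, n, n);  % sparse grid graph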
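
The same construction can be sketched outside MATLAB to make the indexing explicit. This is a minimal pure-Python illustration, not GCMex or sparse_adj_matrix code; the helper names grid_6_neighbors and contrast_weights are hypothetical. Indices are 0-based and column-major (the first dimension varies fastest), which is MATLAB's linear order shifted by one, and every neighboring pair is emitted in both directions so that the resulting matrix would be symmetric.

```python
import math


def grid_6_neighbors(shape):
    """Return (ii, jj): 0-based column-major linear indices of 6-connected pairs.

    Each neighboring pair is listed in both directions, (i, j) and (j, i).
    """
    rows, cols, depth = shape
    ii, jj = [], []

    def add_pair(k, m):
        ii.extend([k, m])
        jj.extend([m, k])

    for s in range(depth):
        for c in range(cols):
            for r in range(rows):
                k = r + rows * (c + cols * s)  # column-major linear index
                if r + 1 < rows:
                    add_pair(k, k + 1)  # next row, same column and slice
                if c + 1 < cols:
                    add_pair(k, k + rows)  # next column, same slice
                if s + 1 < depth:
                    add_pair(k, k + rows * cols)  # same position, next slice
    return ii, jj


def contrast_weights(values, ii, jj, sig2):
    """w_ij = exp(-(v_i - v_j)^2 / (2 * sig2)) for every pair (i, j)."""
    return [math.exp(-((values[i] - values[j]) ** 2) / (2.0 * sig2))
            for i, j in zip(ii, jj)]
```

For example, grid_6_neighbors((2, 2, 2)) yields 24 directed pairs (12 edges, each listed in both directions). The weight function only needs one value per voxel; the answer uses the intensities stored in slices.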
    

Once you have the graph, it's all downhill from here:

    Dc = -reallog(max(reshape(prob, n, 4), eps)).';  % unary/data term, 4-by-n; clamp so that log(0) does not give Inf
    lambda = 2;  % relative weight of the smoothness term
    gch = GraphCut('open', Dc, lambda*(ones(4)-eye(4)), W);  % construct the graph
    [gch L] = GraphCut('expand', gch);  % minimize using "expand" method
    gch = GraphCut('close', gch);  % do not forget to de-allocate
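
To make explicit what GraphCut('expand', ...) is minimizing, the sketch below evaluates the usual Potts energy E(L) = sum_i Dc(L_i, i) + lambda * sum_(i,j) w_ij [L_i != L_j] in pure Python. It is an illustration under assumptions, not GCMex code: the function names are hypothetical, labels are 0-based, and each unordered pair is counted once (i < j); whether GCMex counts a symmetric W once or twice per pair is not checked here.

```python
import math


def unary_costs(prob, eps=1e-12):
    """Dc[c][i] = -log(prob[i][c]); prob is a list of per-voxel class probabilities.

    Probabilities are clamped to eps so that a zero does not produce an infinite cost.
    """
    num_labels = len(prob[0])
    return [[-math.log(max(p[c], eps)) for p in prob] for c in range(num_labels)]


def potts_energy(labels, dc, ii, jj, wij, lam):
    """Data term plus lam * (sum of w_ij over pairs i < j whose labels differ)."""
    data = sum(dc[label][i] for i, label in enumerate(labels))
    smooth = sum(w for i, j, w in zip(ii, jj, wij)
                 if i < j and labels[i] != labels[j])
    return data + lam * smooth
```

With two neighboring voxels whose probabilities are [0.9, 0.1] and [0.4, 0.6], a weight of 1 and lambda = 2, the lowest-energy labeling among the four candidates is (0, 0): the smoothness term pulls the second voxel to the first voxel's label. With lambda = 0 the minimum is (0, 1), the per-voxel argmax. Brute force like this only works for toy sizes; on the real volume, alpha-expansion does the minimization.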
    

To see the output labels, you need to reshape L back to the size of the volume:

    output = reshape(L, size(slices));
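
The reshape works because reshape(prob, n, 4) and reshape(L, size(slices)) both use MATLAB's column-major order, so entry k of Dc and entry k of L refer to the same voxel. A small Python sketch of that index mapping, 0-based and with hypothetical helper names (MATLAB's ind2sub and sub2ind do the 1-based equivalent):

```python
def ind2sub_colmajor(shape, k):
    """0-based column-major linear index k -> (row, col, slice)."""
    rows, cols, _ = shape
    return k % rows, (k // rows) % cols, k // (rows * cols)


def sub2ind_colmajor(shape, r, c, s):
    """(row, col, slice) -> 0-based column-major linear index."""
    rows, cols, _ = shape
    return r + rows * (c + cols * s)


def labels_to_volume(labels, shape):
    """Place a flat column-major label list into vol[row][col][slice]."""
    rows, cols, depth = shape
    vol = [[[None] * depth for _ in range(cols)] for _ in range(rows)]
    for k, label in enumerate(labels):
        r, c, s = ind2sub_colmajor(shape, k)
        vol[r][c][s] = label
    return vol
```

For the 256-by-256-by-60 volume, index 256 is row 0, column 1, slice 0, and index 65536 (= 256*256) is the first voxel of slice 1.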
    

PS: If the spacing between slices is larger than the spacing between neighboring voxels within a slice, you might need to use one sig2 for pairs ii, jj that lie in the same slice and a different sig2 for pairs that lie on different slices. This requires a bit of extra effort.
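
One way to do this, sketched in pure Python under the same 0-based column-major convention: two voxels are in the same slice exactly when their linear indices have the same quotient by rows*cols, so each pair can pick its own sig2. The function name anisotropic_weights and the parameters sig2_in and sig2_across are hypothetical, and how to choose their values for a given slice spacing is left open here.

```python
import math


def anisotropic_weights(values, ii, jj, shape, sig2_in, sig2_across):
    """Contrast weights with a separate sig2 for in-slice and cross-slice pairs.

    ii, jj hold 0-based column-major linear indices; shape is (rows, cols, depth).
    """
    rows, cols, _ = shape
    per_slice = rows * cols
    weights = []
    for i, j in zip(ii, jj):
        same_slice = i // per_slice == j // per_slice  # same slice index?
        sig2 = sig2_in if same_slice else sig2_across
        weights.append(math.exp(-((values[i] - values[j]) ** 2) / (2.0 * sig2)))
    return weights
```

In MATLAB the same test would compare the slice index of ii and jj (for example via ind2sub) and build wij from two sig2 values before calling sparse.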