Tags: numpy, pytorch, bounding-box, torchvision, connected-components

Python: How to extract connected components (bounding boxes) from a 3D NumPy / Torch array?


I have binary segmentation masks for 3D arrays in NumPy/Torch. I would like to find the connected components in these masks and, from them, the bounding box of each component. Note that each array can contain multiple connected components, so I can't just take the minimum and maximum indices of the non-zero values.

For concreteness, suppose I have a 3D array of binary values (I'll use 2D here because it is easier to visualize). I would like to label its connected components. For instance, I would like to take this segmentation mask:

>>> segmentation_mask
array([[1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [1, 1, 1, 0, 0],
       [1, 1, 0, 1, 0],
       [1, 1, 0, 0, 1]], dtype=int32)

and convert it to its connected components, where each component gets an arbitrary integer label, i.e.

>>> connected_components
array([[1, 0, 0, 0, 0],
       [0, 2, 0, 0, 0],
       [2, 2, 2, 0, 0],
       [2, 2, 0, 3, 0],
       [2, 2, 0, 0, 4]], dtype=int32)
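This expected output implies face connectivity (4-connectivity in 2D, 6-connectivity in 3D): elements (0, 0) and (1, 1) touch only diagonally and get different labels. A minimal sketch, assuming SciPy is acceptable, that checks this with `scipy.ndimage.label`. Its default structuring element uses face connectivity, while a "full" structure from `ndimage.generate_binary_structure` also counts diagonal neighbours:

```python
import numpy as np
from scipy import ndimage

segmentation_mask = np.array([[1, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0],
                              [1, 1, 1, 0, 0],
                              [1, 1, 0, 1, 0],
                              [1, 1, 0, 0, 1]], dtype=np.int32)

# Default structure: neighbours must share a face (4-connectivity in 2D).
face_labels, n_face = ndimage.label(segmentation_mask)

# Full structure: diagonal neighbours count as connected (8-connectivity in 2D).
full = ndimage.generate_binary_structure(segmentation_mask.ndim, segmentation_mask.ndim)
full_labels, n_full = ndimage.label(segmentation_mask, structure=full)

print(n_face)  # 4
print(n_full)  # 1: the diagonal chain joins everything into one component
```

So the connectivity rule has to be chosen explicitly; with face connectivity the labels match the array shown above.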

How do I do this with 3D arrays? I'm open to using NumPy, SciPy, torchvision, OpenCV, or any other library.
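For the 3D case, `scipy.ndimage.label` accepts arrays of any rank. A minimal sketch on a small, made-up 3D mask (the array is purely illustrative, not taken from the question); a PyTorch tensor would first need converting to NumPy, e.g. with `tensor.cpu().numpy()`:

```python
import numpy as np
from scipy import ndimage

# Illustrative 3D mask (depth, rows, cols) containing two separate blobs.
mask = np.zeros((3, 4, 4), dtype=np.int32)
mask[0:2, 0:2, 0:2] = 1  # a 2x2x2 cube in one corner
mask[2, 3, 2:4] = 1      # a two-voxel segment in the opposite corner

# Default structure: face connectivity (6-connectivity in 3D).
labels, n_components = ndimage.label(mask)
print(n_components)  # 2
```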


Solution

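  • This builds a graph with one node per array element (plus one extra node shared by all background elements) and labels it with SciPy's connected_components. It should work for any number of dimensions: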

    import numpy as np
    from scipy.sparse import csr_matrix
    from scipy.sparse.csgraph import connected_components

    segmentation_mask = np.array([[1, 0, 0, 0, 0],
                                  [0, 1, 0, 0, 0],
                                  [1, 1, 1, 0, 0],
                                  [1, 1, 0, 1, 0],
                                  [1, 1, 0, 0, 1]], dtype=np.int32)

    # Every array element is a graph node; node index n_nodes is an extra
    # node that all background (zero) elements are attached to.
    row = []
    col = []
    segmentation_mask_reader = segmentation_mask.reshape(-1)
    n_nodes = len(segmentation_mask_reader)
    for node in range(n_nodes):
        idxs = np.unravel_index(node, segmentation_mask.shape)
        if segmentation_mask[idxs] == 0:
            # Background element: link it to the shared background node.
            col.append(n_nodes)
        else:
            # Foreground element: link it to its predecessor along each axis
            # (face connectivity) if that neighbour is also foreground.
            for i in range(len(idxs)):
                if idxs[i] > 0:
                    new_idxs = list(idxs)
                    new_idxs[i] -= 1
                    new_node = np.ravel_multi_index(new_idxs, segmentation_mask.shape)
                    if segmentation_mask_reader[new_node] != 0:
                        col.append(new_node)
        # Every edge added in this iteration starts at the current node.
        while len(col) > len(row):
            row.append(node)

    row = np.array(row, dtype=np.int32)
    col = np.array(col, dtype=np.int32)
    data = np.ones(len(row), dtype=np.int32)

    # Edges are directed, but the default connection='weak' treats them as undirected.
    graph = csr_matrix((data, (row, col)), shape=(n_nodes + 1, n_nodes + 1))
    n_components, labels = connected_components(csgraph=graph)

    # Shift labels so the background component becomes 0 and the foreground
    # components are numbered 1..n_components-1.
    background_label = labels[-1]
    solution = np.zeros(segmentation_mask.shape, dtype=segmentation_mask.dtype)
    solution_writer = solution.reshape(-1)  # a view, so writes update solution
    for node in range(n_nodes):
        label = labels[node]
        if label < background_label:
            solution_writer[node] = label + 1
        elif label > background_label:
            solution_writer[node] = label

    print(solution)
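The code above stops at the label array, while the question title asks for bounding boxes. A minimal sketch, assuming SciPy is available: `scipy.ndimage.find_objects` takes an integer label array such as `solution` and returns, for each label 1..max, a tuple of slices (one per axis) that bounds that component, or `None` if the label is absent. The hard-coded array below is the labelling from the question:

```python
import numpy as np
from scipy import ndimage

# Label array for the question's example (0 = background).
labels = np.array([[1, 0, 0, 0, 0],
                   [0, 2, 0, 0, 0],
                   [2, 2, 2, 0, 0],
                   [2, 2, 0, 3, 0],
                   [2, 2, 0, 0, 4]], dtype=np.int32)

boxes = []
# find_objects returns one entry per label value; index 0 corresponds to label 1.
for label, slices in enumerate(ndimage.find_objects(labels), start=1):
    if slices is None:  # this label value does not occur in the array
        continue
    # Convert half-open slices into inclusive (min, max) corners per axis.
    box_min = tuple(s.start for s in slices)
    box_max = tuple(s.stop - 1 for s in slices)
    boxes.append((label, box_min, box_max))

for box in boxes:
    print(box)
# (1, (0, 0), (0, 0))
# (2, (1, 0), (4, 2))
# (3, (3, 3), (3, 3))
# (4, (4, 4), (4, 4))
```

The same call works unchanged for 3D label arrays, and `labels[slices]` crops a component's box (which may also contain pixels of other components).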