Tags: computer-vision, graph-theory, image-segmentation, ring, voxel

Detect rings/circuits of connected voxels


I have a skeletonized voxel structure that looks like this: [image: example voxel rings]

The actual structure is significantly larger than this example. Is there any way to find the closed rings in the structure? I tried converting it to a graph and using graph-based approaches, but they all have the problem that a graph carries no spatial information about node positions, and hence a graph can have multiple rings that are homologous.

It is not possible to find all the rings and then filter out the ones of interest since the graph is just too large. The size of the rings varies significantly.

Thanks for your help and contribution!

Approaches in any language, and pseudo-code, are welcome, though I work mostly in Python and Matlab.


EDIT:

No, the graph is not planar. The problem with the graph cycle basis is the same as with other simple graph-based approaches: the graph lacks any spatial information, and different spatial configurations can have the same cycle basis, hence the cycle basis does not necessarily correspond to the rings or holes in the structure.

Here is the adjacency matrix in sparse format (columns: NodeID1 NodeID2 Weight):

Pastebin with adjacency matrix

And here are the corresponding X, Y, Z coordinates of the nodes of the graph (columns: X Y Z):

Pastebin with node coordinates
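Below is a minimal stdlib-only sketch of reading these two tables into an adjacency structure plus a coordinate lookup. It rests on several assumptions about the data:

* The input is whitespace-separated text with no header row.
* There is one `NodeID1 NodeID2 Weight` row per matrix entry and one `X Y Z` row per node.
* Node IDs are 1-based row numbers of the coordinate table, as is usual for data exported from Matlab.

The function name `load_graph` is hypothetical.

```python
from collections import defaultdict


def load_graph(edge_text, coord_text, one_based=True):
    """Parse 'NodeID1 NodeID2 Weight' rows and 'X Y Z' rows (hypothetical helper).

    Returns (adjacency, positions): adjacency maps node -> {neighbour: weight},
    positions maps node -> (x, y, z). Row k of the coordinate table is taken to be
    node k (0-based after the optional shift); this is an assumption about the data.
    """
    offset = 1 if one_based else 0

    positions = {}
    coord_rows = [line for line in coord_text.splitlines() if line.strip()]
    for row, line in enumerate(coord_rows):
        x, y, z = map(float, line.split())
        positions[row] = (x, y, z)

    adjacency = defaultdict(dict)
    for line in edge_text.splitlines():
        if not line.strip():
            continue
        a, b, w = line.split()
        u, v = int(float(a)) - offset, int(float(b)) - offset
        if u == v:
            continue  # ignore self-loops, e.g. a non-zero diagonal entry
        # a symmetric sparse matrix lists every undirected edge twice;
        # writing both directions into a dict makes the duplicate harmless
        adjacency[u][v] = float(w)
        adjacency[v][u] = float(w)
    return dict(adjacency), positions
```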



Solution

  • First I reduce the size of the problem considerably by contracting neighbouring nodes of degree 2 into hypernodes: each simple chain in the graph is substituted with a single node.

    Then I find the cycle basis for which the maximum cost of the cycles in the basis set is minimal.
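A quick check on why a basis keeps this tractable when enumerating every ring is not. Any cycle basis of a graph contains exactly E - V + C cycles (edges minus nodes plus connected components). That number grows linearly with the graph, whereas the number of simple cycles can grow exponentially.

Contracting a chain of k degree-2 nodes into one hypernode removes k - 1 nodes and k - 1 edges, so it leaves this count unchanged. That holds provided the two ends of the chain attach to different nodes. In a simple `nx.Graph`, a chain whose ends attach to the same node collapses onto a single edge and its ring drops out.

Below is a minimal stdlib-only sketch using union-find; `cyclomatic_number` is a hypothetical helper name.

```python
def cyclomatic_number(nodes, edges):
    """Return E - V + C, the number of cycles in any cycle basis of a simple undirected graph."""
    parent = {v: v for v in nodes}

    def find(v):
        # union-find root lookup with path halving to keep the trees shallow
        while parent[v] != v:
            parent[v] = parent[parent[v]]
            v = parent[v]
        return v

    components = len(parent)
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv  # merging two components
            components -= 1
    return len(edges) - len(parent) + components


# two unit squares sharing an edge: 6 nodes, 7 edges, 1 component
grid_nodes = [(x, y) for x in range(3) for y in range(2)]
grid_edges = [((0, 0), (1, 0)), ((1, 0), (2, 0)), ((0, 1), (1, 1)), ((1, 1), (2, 1)),
              ((0, 0), (0, 1)), ((1, 0), (1, 1)), ((2, 0), (2, 1))]
print(cyclomatic_number(grid_nodes, grid_edges))  # 2
```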

    For the central part of the network, the solution can easily be plotted as it is planar:

    [Figure: cycle basis plotted on the planar central part of the network]

    For some reason, I fail to identify the correct cycle basis, but I think the following should get you started, and maybe somebody else can chime in.

    Recover data from the posted image (the OP had not provided any real data)

    import itertools
    from functools import partial  # used with find_holes() further down
    
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import PathPatch
    from matplotlib.path import Path
    import networkx as nx
    from skimage.morphology import medial_axis, binary_closing
    
    img = plt.imread("tissue_skeleton_crop.jpg")
    # plt.hist(np.mean(img, axis=-1).ravel(), bins=255) # find a good cutoff
    bw = np.mean(img, axis=-1) < 200
    # plt.imshow(bw, cmap='gray')
    # connect disconnected segments (this keyword was called `selem` before scikit-image 0.19)
    closed = binary_closing(bw, footprint=np.ones((50, 50)))
    # plt.imshow(closed, cmap='gray')
    skeleton = medial_axis(closed)
    
    fig, ax = plt.subplots(1,1)
    ax.imshow(skeleton, cmap='gray')
    ax.set_xticks([])
    ax.set_yticks([])
    

    [Figure: skeleton recovered from the posted image]

    def img_to_graph(binary_img, allowed_steps):
        """
        Arguments:
        ----------
        binary_img    -- 2D boolean array marking the position of nodes
        allowed_steps -- list of allowed steps; e.g. [(0, 1), (1, 1)] signifies that
                         from node with position (i, j) nodes at position (i, j+1)
                         and (i+1, j+1) are accessible.
    
        Returns:
        --------
        g             -- networkx.Graph() instance
        idx_to_pos    -- dict mapping node idx to (i, j) position (for plotting)
        pos_to_idx    -- dict mapping (i, j) position to node idx (for testing if path exists)
        """
    
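        # map array indices to node indices and vice versa; rot90(..., 3) turns (row, col)
        # into (x, y) plotting coordinates. zip() returns a one-shot iterator in Python 3,
        # so materialise it as a list before using it twice.
        node_pos = list(zip(*np.where(np.rot90(binary_img, 3))))
        node_idx = range(len(node_pos))
        pos_to_idx = dict(zip(node_pos, node_idx))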
    
        # create graph
        g = nx.Graph()
        for (i, j) in node_pos:
            for (delta_i, delta_j) in allowed_steps: # try to step in all allowed directions
                if (i+delta_i, j+delta_j) in pos_to_idx: # i.e. target node also exists
                    g.add_edge(pos_to_idx[(i,j)], pos_to_idx[(i+delta_i, j+delta_j)])
    
        idx_to_pos = dict(zip(node_idx, node_pos))
    
        return g, idx_to_pos, pos_to_idx
    
    allowed_steps = set(itertools.product((-1, 0, 1), repeat=2)) - set([(0,0)])
    g, idx_to_pos, pos_to_idx = img_to_graph(skeleton, allowed_steps)
    
    fig, ax = plt.subplots(1,1)
    nx.draw(g, pos=idx_to_pos, node_size=1, ax=ax)
    

    [Figure: pixel graph of the skeleton]

    NB: These are not red lines; they are lots of red dots, one per node in the graph.
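`allowed_steps` above is the full 8-neighbourhood of a pixel. Since the OP's data are voxels, here is a minimal stdlib-only sketch of the same edge construction for any dimension. It has a switch between the full neighbourhood (8 neighbours in 2D, 26 in 3D) and face neighbours only (4 in 2D, 6 in 3D). `neighbour_edges` is a hypothetical name.

The usage lines show a side effect of the full neighbourhood: a one-pixel-wide ring around a single hole already contains small triangles at its corners.

```python
from itertools import product


def neighbour_edges(voxels, full=True):
    """Undirected edges between occupied pixels/voxels given as integer coordinate tuples.

    full=True  -> all offsets in {-1, 0, 1}^d except the zero offset (8 in 2D, 26 in 3D)
    full=False -> offsets along a single axis only (4 in 2D, 6 in 3D)
    """
    voxels = set(voxels)
    if not voxels:
        return set()
    dim = len(next(iter(voxels)))
    offsets = [d for d in product((-1, 0, 1), repeat=dim) if any(d)]
    if not full:
        offsets = [d for d in offsets if sum(map(abs, d)) == 1]
    edges = set()
    for v in voxels:
        for d in offsets:
            w = tuple(a + b for a, b in zip(v, d))
            if w in voxels:
                edges.add(frozenset((v, w)))  # frozenset: (v, w) and (w, v) are the same edge
    return edges


# 8 pixels around a single hole at (1, 1)
ring = {(x, y) for x in range(3) for y in range(3)} - {(1, 1)}
print(len(neighbour_edges(ring, full=True)))   # 12
print(len(neighbour_edges(ring, full=False)))  # 8
```

With the full neighbourhood, the 8 ring pixels carry 12 edges. That gives 12 - 8 + 1 = 5 independent cycles (four corner triangles plus the hole itself), against 8 - 8 + 1 = 1 with face neighbours only. Face-only connectivity is not a drop-in fix, though: it disconnects skeleton branches that only touch diagonally.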

    Contract Graph

    def contract(g):
        """
        Contract chains of neighbouring vertices with degree 2 into one hypernode.
    
        Arguments:
        ----------
        g -- networkx.Graph or networkx.DiGraph instance
    
        Returns:
        --------
        h -- networkx.Graph or networkx.DiGraph instance
            the contracted graph
    
        hypernode_to_nodes -- dict: int hypernode -> [v1, v2, ..., vn]
            dictionary mapping hypernodes to nodes
    
        """
    
        # create subgraph of all nodes with degree 2
        is_chain = [node for node, degree in g.degree() if degree == 2]
        chains = g.subgraph(is_chain)
    
        # contract connected components (which should be chains of variable length) into single node
        # (connected_component_subgraphs() was removed in networkx 2.4)
        components = [chains.subgraph(c).copy() for c in nx.connected_components(chains)]
        hypernode = g.number_of_nodes()
        hypernodes = []
        hyperedges = []
        hypernode_to_nodes = dict()
        false_alarms = []
        for component in components:
            if component.number_of_nodes() > 1:
    
                hypernodes.append(hypernode)
                vs = [node for node in component.nodes()]
                hypernode_to_nodes[hypernode] = vs
    
                # create new edges from the neighbours of the chain ends to the hypernode
                component_edges = [e for e in component.edges()]
                for v, w in [e for e in g.edges(vs) if not ((e in component_edges) or (e[::-1] in component_edges))]:
                    if v in component:
                        hyperedges.append([hypernode, w])
                    else:
                        hyperedges.append([v, hypernode])
    
                hypernode += 1
    
            else: # nothing to collapse as there is only a single node in component:
                false_alarms.extend([node for node in component.nodes()])
    
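        # initialise new graph with all other nodes; subgraph() returns a read-only view
        # in networkx >= 2.0, so copy it before adding the hypernodes and hyperedges
        chain_nodes = set(is_chain)
        not_chain = [node for node in g.nodes() if node not in chain_nodes]
        h = g.subgraph(not_chain + false_alarms).copy()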
        h.add_nodes_from(hypernodes)
        h.add_edges_from(hyperedges)
    
        return h, hypernode_to_nodes
    
    h, hypernode_to_nodes = contract(g)
    
    # set position of hypernode to position of centre of chain
    for hypernode, nodes in hypernode_to_nodes.items():
        chain = g.subgraph(nodes)
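        ends = [node for node, degree in chain.degree() if degree == 1]
        # a closed ring of degree-2 nodes has no ends; fall back to its node list
        path = nx.shortest_path(chain, *ends) if len(ends) == 2 else list(nodes)
        centre = path[len(path) // 2]  # integer division: list indices must be ints in Python 3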
        idx_to_pos[hypernode] = idx_to_pos[centre]
    
    fig, ax = plt.subplots(1,1)
    nx.draw(h, pos=idx_to_pos, node_size=20, ax=ax)
    

    [Figure: contracted graph]
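Here is an alternative to hypernodes, sketched with the standard library only. It replaces every chain by a single edge between the two junctions it connects and keeps the chain length as the edge's weight. The result is a multigraph (a plain list of edges), so parallel chains between the same pair of junctions are kept rather than merged, and the weight can serve directly as the cycle cost.

`contract_chains` is a hypothetical name. Rings made only of degree-2 nodes, with no junction at all, are not reported by this sketch. They would have to be picked up separately, e.g. as connected components in which every node has degree 2.

```python
def contract_chains(adj):
    """Collapse chains of degree-2 nodes into weighted edges between junctions.

    adj maps node -> list of neighbours (undirected, no self-loops).
    Returns a list of (u, v, length), where length counts the original edges.
    Components consisting only of degree-2 nodes (isolated rings) are skipped.
    """
    junctions = {v for v, nbrs in adj.items() if len(nbrs) != 2}
    walked = set()  # (start, first step) pairs already covered, from either end
    edges = []
    for u in junctions:
        for first in adj[u]:
            if (u, first) in walked:
                continue
            prev, cur, length = u, first, 1
            while cur not in junctions:  # cur has exactly two neighbours
                a, b = adj[cur]
                prev, cur = cur, (b if a == prev else a)
                length += 1
            walked.add((u, first))
            walked.add((cur, prev))  # the same chain seen from its other end
            edges.append((u, cur, length))
    return edges


# two junctions (0 and 1) joined by chains of 3, 2 and 1 edges
adj = {0: [2, 4, 1], 1: [3, 4, 0], 2: [0, 3], 3: [2, 1], 4: [0, 1]}
print(sorted(length for _, _, length in contract_chains(adj)))  # [1, 2, 3]
```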

    Find cycle basis

    cycle_basis = nx.cycle_basis(h)
    
    fig, ax = plt.subplots(1,1)
    nx.draw(h, pos=idx_to_pos, node_size=10, ax=ax)
    for cycle in cycle_basis:
        vertices = [idx_to_pos[idx] for idx in cycle]
        path = Path(vertices)
        ax.add_artist(PathPatch(path, facecolor=np.random.rand(3)))
    

    TODO:

    Find the correct cycle basis (I might be confused what the cycle basis is or networkx might have a bug).
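On the TODO: this is most likely not a networkx bug. According to its documentation, `nx.cycle_basis` is adapted from Paton's algorithm for finding a *fundamental* set of cycles. Such an algorithm grows a spanning tree and closes one cycle per non-tree edge, so the cycles you get depend on the tree, and nothing forces them to be the smallest "holes".

The stdlib-only sketch below builds a BFS tree and reproduces the effect on two unit squares sharing an edge. With this neighbour ordering, one of the two basis cycles is the outer hexagon. XOR-ing it with the other cycle (the symmetric difference of their edge sets) recovers the missing square. `grid_adjacency` and `fundamental_cycles` are hypothetical helpers, not networkx functions.

```python
from collections import deque


def grid_adjacency(width, height):
    """4-connected grid graph; neighbours are listed in a fixed offset order."""
    cells = {(x, y) for x in range(width) for y in range(height)}
    steps = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    return {(x, y): [(x + dx, y + dy) for dx, dy in steps if (x + dx, y + dy) in cells]
            for (x, y) in cells}


def fundamental_cycles(adj, root):
    """One cycle (a frozenset of undirected edges) per non-tree edge of a BFS tree.

    Only the connected component containing root is covered.
    """
    parent, depth = {root: None}, {root: 0}
    tree = set()
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in parent:
                parent[v], depth[v] = u, depth[u] + 1
                tree.add(frozenset((u, v)))
                queue.append(v)

    cycles, done = [], set()
    for u in parent:
        for v in adj[u]:
            edge = frozenset((u, v))
            if edge in tree or edge in done:
                continue
            done.add(edge)
            # close the cycle: climb from both ends up to their lowest common ancestor
            cycle, a, b = {edge}, u, v
            while a != b:
                if depth[a] >= depth[b]:
                    cycle.add(frozenset((a, parent[a])))
                    a = parent[a]
                else:
                    cycle.add(frozenset((b, parent[b])))
                    b = parent[b]
            cycles.append(frozenset(cycle))
    return cycles


cycles = fundamental_cycles(grid_adjacency(3, 2), (0, 0))
print(sorted(len(c) for c in cycles))  # [4, 6]: one square plus the outer hexagon
square, hexagon = sorted(cycles, key=len)
print(len(square ^ hexagon))  # 4: the symmetric difference is the other square
```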

    EDIT

    Holy crap, this was a tour de force. I should never have gone down this rabbit hole.


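    So the idea is now that we want to find the cycle basis for which the maximum cost of the cycles in the basis is minimal. We set the cost of a cycle to its length in the original, uncontracted graph (each hypernode counts as the number of pixels in its chain, see `_get_cost`), but one could imagine other cost functions. To do so, we find an initial cycle basis and then repeatedly combine pairs of cycles (via the symmetric difference of their edges), keeping the cheapest cycles that still cover all nodes, until the set of cycles no longer changes.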
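The node-cover rule in `_select_cycles` below is a heuristic. Covering every node does not guarantee that the chosen cycles are independent or that they span all cycles; the outer hexagon of two adjacent squares, for instance, covers every node on its own.

If cycles are represented as edge sets instead, "combining" two cycles is the XOR of their edge sets. A proper basis can then be chosen with the matroid greedy algorithm, which I am swapping in here: scan candidate cycles in order of increasing cost and keep a cycle only if it is not the XOR of cycles already kept. Because cycles over GF(2) form a matroid, the result has minimum total cost and also minimises the largest individual cost, which is the property asked for here.

A few caveats on this stdlib-only sketch:

* `greedy_cycle_basis` and `cycle_edges` are hypothetical names.
* `cost` defaults to the number of edges.
* It assumes the candidate list already contains a minimal basis; Horton's algorithm builds such a candidate set from shortest paths. With fewer candidates, it returns the cheapest independent subset of what it is given.

```python
def greedy_cycle_basis(candidates, cost=len):
    """Return the cheapest linearly independent subset of candidate cycles.

    Cycles are frozensets of undirected edges; XOR of edge sets is cycle addition
    over GF(2). Candidates are scanned by increasing cost and a cycle is kept only
    if it is not the XOR of cycles kept so far.
    """
    bit_of = {}   # edge -> bit position in the bitmask encoding
    pivots = {}   # highest set bit -> reduced bitmask of a kept cycle
    basis = []
    for cycle in sorted(candidates, key=cost):
        mask = 0
        for edge in cycle:
            mask |= 1 << bit_of.setdefault(edge, len(bit_of))
        # Gaussian elimination over GF(2): the mask is reduced to 0
        # exactly when the cycle is a XOR of cycles already in the basis
        while mask:
            top = mask.bit_length() - 1
            if top not in pivots:
                pivots[top] = mask
                basis.append(cycle)
                break
            mask ^= pivots[top]
    return basis


def cycle_edges(*nodes):
    """Hypothetical helper: edge set of the closed walk nodes[0] -> ... -> nodes[-1] -> nodes[0]."""
    return frozenset(frozenset(pair) for pair in zip(nodes, nodes[1:] + nodes[:1]))


left = cycle_edges((0, 0), (1, 0), (1, 1), (0, 1))
right = cycle_edges((1, 0), (2, 0), (2, 1), (1, 1))
outer = cycle_edges((0, 0), (1, 0), (2, 0), (2, 1), (1, 1), (0, 1))
print(greedy_cycle_basis([outer, left, right]) == [left, right])  # True: outer = left XOR right
```

If using networkx for this step is acceptable, `nx.minimum_cycle_basis(G, weight=...)` returns a minimum-weight cycle basis directly (as lists of nodes). By the same matroid argument, it also minimises the largest cycle cost. Its running time grows quickly with graph size, so contracting the chains first still helps.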

    def find_holes(graph, cost_function):
        """
        Find the cycle basis, that minimises the maximum individual cost of the cycles in the basis set.
        """
    
        # get cycle basis
        cycles = nx.cycle_basis(graph)
    
        # find new basis set that minimises maximum cost
        old_basis = set()
        new_basis = set(frozenset(cycle) for cycle in cycles) # only frozensets are hashable
        while new_basis != old_basis:
            old_basis = new_basis
            for cycle_a, cycle_b in itertools.combinations(old_basis, 2):
                if len(cycle_a & cycle_b) >= 2:  # only combine cycles sharing >= 2 nodes (necessary to share an edge)
                    cycle_c = _symmetric_difference(graph, cycle_a, cycle_b)
                    new_basis = new_basis.union([cycle_c])
            new_basis = _select_cycles(new_basis, cost_function)
    
        ordered_cycles = [_order_nodes_in_cycle(graph, nodes) for nodes in new_basis]
        return ordered_cycles
    
    def _symmetric_difference(graph, cycle_a, cycle_b):
        # get edges
        edges_a = list(graph.subgraph(cycle_a).edges())
        edges_b = list(graph.subgraph(cycle_b).edges())
    
        # also get reverse edges as graph undirected
        edges_a += [e[::-1] for e in edges_a]
        edges_b += [e[::-1] for e in edges_b]
    
        # find edges that are in either but not in both
        edges_c = set(edges_a) ^ set(edges_b)
    
        cycle_c = frozenset(nx.Graph(list(edges_c)).nodes())
        return cycle_c
    
    def _select_cycles(cycles, cost_function):
        """
        Select cover of nodes with cycles that minimises the maximum cost
        associated with all cycles in the cover.
        """
        cycles = list(cycles)
        costs = [cost_function(cycle) for cycle in cycles]
        order = np.argsort(costs)
    
        nodes = frozenset.union(*cycles)
        covered = set()
        basis = []
    
        # greedy; start with lowest cost
        for ii in order:
            cycle = cycles[ii]
            if cycle <= covered:
                pass
            else:
                basis.append(cycle)
                covered |= cycle
                if covered == nodes:
                    break
    
        return set(basis)
    
    def _get_cost(cycle, hypernode_to_nodes):
        cost = 0
        for node in cycle:
            if node in hypernode_to_nodes:
                cost += len(hypernode_to_nodes[node])
            else:
                cost += 1
        return cost
    
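    def _order_nodes_in_cycle(graph, nodes):
        # assumes the nodes induce a single chordless cycle; otherwise cycle_basis()
        # returns more than one cycle and the unpacking below raises a ValueError
        order, = nx.cycle_basis(graph.subgraph(nodes))
        return order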
    
    holes = find_holes(h, cost_function=partial(_get_cost, hypernode_to_nodes=hypernode_to_nodes))
    
    fig, ax = plt.subplots(1,1)
    nx.draw(h, pos=idx_to_pos, node_size=10, ax=ax)
    for ii, hole in enumerate(holes):
        if (len(hole) > 3):
            vertices = np.array([idx_to_pos[idx] for idx in hole])
            path = Path(vertices)
            ax.add_artist(PathPatch(path, facecolor=np.random.rand(3)))
            xmin, ymin = np.min(vertices, axis=0)
            xmax, ymax = np.max(vertices, axis=0)
            x = xmin + (xmax-xmin) / 2.
            y = ymin + (ymax-ymin) / 2.
            # ax.text(x, y, str(ii))
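The cycles above still ignore the node coordinates that the OP is worried about. In 2D, one cheap way to bring them back is to flag a candidate cycle that encloses other nodes, since a genuine hole should be empty. Below is a minimal stdlib-only sketch of the ray-casting (even-odd) point-in-polygon test that such a filter could use.

A few caveats:

* `point_in_polygon` and `encloses_other_nodes` are hypothetical helpers.
* It makes sense only for a planar 2D drawing like the central part of the network, not for the OP's non-planar 3D voxel graph.
* Pass positions of the contracted graph's nodes only (e.g. `{n: idx_to_pos[n] for n in h}`). Pixels hidden inside hypernodes lie close to the straight-line polygon and could be counted spuriously.

```python
def point_in_polygon(point, polygon):
    """Even-odd ray casting: is point inside the closed polygon [(x, y), ...]?

    Points exactly on the boundary may be classified either way.
    """
    x, y = point
    inside = False
    n = len(polygon)
    for k in range(n):
        x1, y1 = polygon[k]
        x2, y2 = polygon[(k + 1) % n]
        # does a horizontal ray going right from point cross edge (x1, y1)-(x2, y2)?
        if (y1 > y) != (y2 > y):  # implies y1 != y2, so no division by zero
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def encloses_other_nodes(cycle, node_to_pos):
    """True if any node not on the cycle lies inside the polygon traced by the cycle."""
    polygon = [node_to_pos[node] for node in cycle]
    on_cycle = set(cycle)
    return any(point_in_polygon(pos, polygon)
               for node, pos in node_to_pos.items() if node not in on_cycle)


square = [(0, 0), (2, 0), (2, 2), (0, 2)]
print(point_in_polygon((1, 1), square), point_in_polygon((3, 1), square))  # True False
```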