Tags: python, graph-theory, igraph, graph-tool, networkit

Time-efficient way to find connected spheres paths in Python


I have written code that finds connected sphere paths using the NetworkX library in Python. To do so, I first need the distances between the spheres, before building the graph. This part of the code (the calculation section, i.e. the numba function that finds distances and connections) led to memory leaks when I used arrays in a parallel numba scheme (I had the same problem with np.linalg and scipy.spatial.distance.cdist). So I wrote a non-parallel numba version that uses lists instead. It is now memory-friendly (it uses only ~10-20% of 16 GB of RAM and ~30-40% of each core of my 4-core CPU), but calculating the distances takes a long time. For example, with ~12000 spheres, the calculation section and the NetworkX graph creation each took less than one second; with ~550000 spheres, the calculation section (the numba part) took around 25 minutes, while creating the graph and getting the output list took only 7 seconds.
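
For reference, the connection test in the code below treats spheres i and j (centres p_i and p_j, radii r_i and r_j) as connected when the distance between their centres is at most the sum of their radii. Both sides are non-negative, so the same test can equivalently be done on squared values, which avoids one square root per pair:

```latex
\lVert \mathbf{p}_i - \mathbf{p}_j \rVert_2 \le r_i + r_j
\quad\Longleftrightarrow\quad
\sum_{k=1}^{3} \left(p_{i,k} - p_{j,k}\right)^2 \le \left(r_i + r_j\right)^2
```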

import numpy as np
import numba as nb
import networkx as nx


radii = np.load('rad_dist_12000.npy')
poss = np.load('pos_dist_12000.npy')


@nb.njit("(Tuple([float64[:, ::1], float64[:, ::1]]))(float64[::1], float64[:, ::1])", parallel=True)
def distances_numba_parallel(radii, poss):
    radii_arr = np.zeros((radii.shape[0], radii.shape[0]), dtype=np.float64)
    poss_arr = np.zeros((poss.shape[0], poss.shape[0]), dtype=np.float64)
    for i in nb.prange(radii.shape[0] - 1):
        for j in range(i+1, radii.shape[0]):
            radii_arr[i, j] = radii[i] + radii[j]
            poss_arr[i, j] = ((poss[i, 0] - poss[j, 0]) ** 2 + (poss[i, 1] - poss[j, 1]) ** 2 + (poss[i, 2] - poss[j, 2]) ** 2) ** 0.5
    return radii_arr, poss_arr


@nb.njit("(List(UniTuple(int64, 2)))(float64[::1], float64[:, ::1])")
def distances_numba_non_parallel(radii, poss):
    connections = []
    for i in range(radii.shape[0] - 1):
        connections.append((i, i))
        for j in range(i+1, radii.shape[0]):
            radii_arr_ij = radii[i] + radii[j]
            poss_arr_ij = ((poss[i, 0] - poss[j, 0]) ** 2 + (poss[i, 1] - poss[j, 1]) ** 2 + (poss[i, 2] - poss[j, 2]) ** 2) ** 0.5
            if poss_arr_ij <= radii_arr_ij:
                connections.append((i, j))
    return connections


def connected_spheres_path(radii, poss):
    
    # in parallel mode
    # maximum_distances, distances = distances_numba_parallel(radii, poss)
    # connections = distances <= maximum_distances
    # connections[np.tril_indices_from(connections, -1)] = False
    
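    # in non-parallel mode
    connections = distances_numba_non_parallel(radii, poss)

    G = nx.Graph(connections)
    # The (i, i) self-loops only cover i < N - 1, so add every sphere explicitly;
    # otherwise an isolated last sphere would be missing from the components.
    G.add_nodes_from(range(radii.shape[0]))
    return list(nx.connected_components(G))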

My datasets will contain at most 10 million spheres (the data are positions and radii), mostly up to 1 million. As mentioned above, most of the time is spent in the calculation section. I have little experience with graphs and don't know whether (and how) this can be made much faster by using all CPU cores or more RAM (max 12 GB), or whether it can be computed internally by other Python libraries such as graph-tool, igraph and NetworKit, doing the whole process efficiently in C or C++ (I suspect it may not be necessary to find the connected spheres separately before building the graph).
I would be grateful for any answer that makes my code faster for large data volumes. Performance is the first priority; if large amounts of memory are needed for large data volumes, mentioning how much (with some benchmarks) would be helpful.


Update:

Since using trees alone does not improve performance enough, I have written a more optimized version of the calculation section that combines tree-based algorithms with numba JIT compilation.
Now I am curious whether the calculation (which is an integral part of, and the basic requirement for, building such a graph) can be done internally by other Python libraries such as graph-tool, igraph and NetworKit, doing the whole process efficiently in C or C++.


Data

Sample radii and poss arrays were provided for 12000, 50000 and 550000 spheres.


Solution

  • If you are computing the pairwise distance between all points, that's O(N^2) calculations (N(N-1)/2 distinct pairs), which will take a very long time for sufficiently many data points.
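
    To put numbers on this, here is the pair count for the largest question dataset, and the memory a dense N × N float64 matrix needs (the question's parallel variant allocates two of them):

```latex
\binom{N}{2} = \frac{N(N-1)}{2}, \qquad
N = 550\,000:\ \frac{550\,000 \times 549\,999}{2} \approx 1.51 \times 10^{11}\ \text{pairs}

\text{bytes per } N \times N \text{ float64 matrix} = 8N^2:\quad
N = 12\,000 \Rightarrow 1.15 \times 10^{9}\ (\approx 1.15\ \text{GB}),\quad
N = 550\,000 \Rightarrow 2.42 \times 10^{12}\ (\approx 2.4\ \text{TB})
```

    Two such matrices already exceed 16 GB at around N ≈ 32 000, so the dense approach cannot scale, and even the list-based version still has to test about 1.5 × 10^11 pairs for the 550000-sphere data.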

    If you can place an upper bound on the distance you need to consider for any two points, then there are some nice data structures for finding pairs of neighbors in a set of points. If you already have SciPy installed, then the most convenient structure to reach for is the KDTree (or the optimized version, cKDTree).
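
    A minimal illustration of the neighbour-pair query on made-up 2D points (the coordinates and the search distance of 1.5 are arbitrary): `query_pairs` returns every index pair (i, j) with i < j whose points lie within the given distance, which here is only the first two points.

```python
import numpy as np
from scipy.spatial import cKDTree

# Three points on a line; only the first two are within 1.5 units of each other.
points = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]])
tree = cKDTree(points)

# All index pairs (i, j) with i < j and Euclidean distance <= 1.5, as an (M, 2) array.
pairs = tree.query_pairs(r=1.5, output_type='ndarray')
print(pairs)  # → [[0 1]]
```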

    The basic recipe is:

    • Load your point set into the KDTree.
    • Ask the KDTree for all pairs of points which are within some maximum distance from each other.
    • Calculate the actual distances between each of the returned points.
    • Compare those distances with the summed radii associated with the point pair. Drop the pairs whose distances are too large.
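
    The four steps can be condensed into one function. This is a sketch: the name `find_touching_pairs` and the three test spheres are made up for illustration. In the example, the pair (1, 2) is returned by the KD-tree (its centre distance of 3.5 is within the search bound 2 × 2 = 4) but is dropped in the last step, because its summed radii are only 2.5.

```python
import numpy as np
from scipy.spatial import cKDTree


def find_touching_pairs(points, radii):
    """Return an (M, 2) array of index pairs (i, j), i < j, whose spheres touch or overlap."""
    # Step 1: load the points into a KD-tree.
    tree = cKDTree(points)
    # Step 2: candidate pairs. Two spheres cannot touch if their centres are
    # farther apart than twice the largest radius, so that is a safe bound.
    candidates = tree.query_pairs(r=2 * radii.max(), output_type='ndarray')
    # Step 3: actual centre-to-centre distances of the candidates.
    distances = np.linalg.norm(points[candidates[:, 0]] - points[candidates[:, 1]], axis=1)
    # Step 4: keep only the pairs whose summed radii span that distance.
    return candidates[distances <= radii[candidates].sum(axis=1)]


# Three made-up spheres on the x axis.
points = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [6.5, 0.0, 0.0]])
radii = np.array([2.0, 2.0, 0.5])
print(find_touching_pairs(points, radii))  # → [[0 1]]
```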

    Finally, you need to determine the clusters of spheres. Your question mentions "paths", but in your example code you're only concerned with connected components. Of course you can use networkx or graph-tool for that, but maybe that's overkill.

    If connected components are all you need, then you don't even need a proper graph data structure. You just need a way to find the groups of linked nodes, without maintaining the specific connections that linked them. Again, SciPy has a nice tool: DisjointSet.
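
    A small illustration of how DisjointSet behaves (it requires SciPy 1.6 or newer; the five elements and the merged pairs are made up): `merge` links two elements, `connected` checks whether they ended up in the same group, and `subsets` returns the groups as Python sets.

```python
from scipy.cluster.hierarchy import DisjointSet

# Five spheres; the pairs (0, 1), (1, 2) and (3, 4) touch.
ds = DisjointSet(range(5))
for u, v in [(0, 1), (1, 2), (3, 4)]:
    ds.merge(u, v)

print(ds.connected(0, 2))  # → True  (linked through sphere 1)
print(ds.connected(2, 3))  # → False
# subsets() returns a list of sets; sort them for a stable printout.
groups = sorted(sorted(s) for s in ds.subsets())
print(groups)  # → [[0, 1, 2], [3, 4]]
```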

    Here is a complete example. The execution time depends on not only the number of points, but how "dense" they are. I tried some reasonable (I think) test data with 1M points, which took 24 seconds to process on my laptop.

    Your example data (the largest of the sets provided above) takes longer: about 45 seconds. The KDTree finds 312M pairs of points to consider, of which fewer than 1M are actually valid connections.

    import numpy as np
    from scipy.spatial import cKDTree
    from scipy.cluster.hierarchy import DisjointSet
    
    ## Example data (2D)
    # N = 1000
    # D = 2
    # max_point = 1000
    # min_radius = 10
    # max_radius = 20
    # points = np.random.randint(0, max_point, size=(N, D))
    # radii = np.random.randint(min_radius, max_radius+1, size=N)
    
    ## Example data (3D)
    # N = 1_000_000
    # D = 3
    # max_point = 3000
    # min_radius = 10
    # max_radius = 20
    # points = np.random.randint(0, max_point, size=(N, D))
    # radii = np.random.randint(min_radius, max_radius+1, size=N)
    
    
    # Question data (3D)
    points = np.load('b (556024).npy')
    radii = np.load('a (556024).npy')
    N = len(points)
    
    # Load into a KD tree and extract all pairs which could possibly be linked
    # (using twice the maximum radius as the upper bound of the search distance).
    kd = cKDTree(points)
    pairs = kd.query_pairs(2 * radii.max(), output_type='ndarray')
    
    def filter_pairs(pairs):
        # Calculate the distance between each pair of points
        vectors = points[pairs[:, 1]] - points[pairs[:, 0]]
        distances = np.linalg.norm(vectors, axis=1)
    
        # Drop the pairs whose summed radii aren't large enough
        # to span the distance between the points.
        thresholds = radii[pairs].sum(axis=1)
        return pairs[distances <= thresholds]
    
    # We could do this in one big step
    # ...but that might require lots of RAM.
    # It's cheaper to do it in big chunks, in a loop.
    fp = []
    CHUNK = 1_000_000
    for i in range(0, len(pairs), CHUNK):
        fp.append(filter_pairs(pairs[i:i+CHUNK]))
    filtered_pairs = np.concatenate(fp)
    
    # Load the pairs into a DisjointSet (a.k.a. UnionFind)
    # data structure and extract the groups.
    ds = DisjointSet(range(N))
    for u, v in filtered_pairs:
        ds.merge(u, v)
    connected_sets = list(ds.subsets())
    
    print(f"Found {len(connected_sets)} sets of circles/spheres")
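
    The `ds.merge` loop above runs in Python, once per valid pair. As an alternative (a different SciPy tool from the one used above), `scipy.sparse.csgraph.connected_components` computes the same grouping in compiled code from a sparse adjacency matrix. The sketch below assumes the pairs come as an (M, 2) integer array like `filtered_pairs`; `group_pairs` is a hypothetical helper name. Like `DisjointSet(range(N))`, it also reports isolated spheres as single-element groups.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components


def group_pairs(pairs, n_nodes):
    """Hypothetical helper: group node indices 0..n_nodes-1 into connected components."""
    # Sparse adjacency matrix with one stored entry per touching pair.
    data = np.ones(len(pairs))
    graph = coo_matrix((data, (pairs[:, 0], pairs[:, 1])), shape=(n_nodes, n_nodes))
    # directed=False treats each stored (i, j) entry as an undirected edge.
    n_components, labels = connected_components(graph, directed=False)
    # Sort node indices by component label, then split into one array per component.
    order = np.argsort(labels, kind='stable')
    boundaries = np.cumsum(np.bincount(labels, minlength=n_components))[:-1]
    return np.split(order, boundaries)


# Six made-up spheres: 0-1-2 are linked, 3 is isolated, 4-5 are linked.
example_pairs = np.array([[0, 1], [1, 2], [4, 5]])
print([g.tolist() for g in group_pairs(example_pairs, 6)])  # three groups: {0, 1, 2}, {3}, {4, 5}
```

    In the script above it could stand in for the DisjointSet step as `connected_sets = group_pairs(filtered_pairs, N)`, with the difference that each group is a NumPy index array rather than a set.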
    

    Just for fun, here's a visualization of the 2D test data:

    from bokeh.plotting import output_notebook, figure, show
    output_notebook()
    
    p = figure()
    p.circle(*points.T, radius=radii, fill_alpha=0.25)
    p.segment(*points[filtered_pairs[:, 0]].T,
              *points[filtered_pairs[:, 1]].T,
              line_color='red')
    show(p)
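
    One possible refinement, not part of the timings above: with the question data, searching every point with the same bound 2 * radii.max() produced 312M candidates for fewer than 1M real connections, which suggests most spheres are much smaller than the largest one. A tighter bound that is still safe for point i is r_i + max(r), since d_ij <= r_i + r_j <= r_i + max(r). The sketch below (hypothetical helper `touching_pairs_per_point_bound`) uses `cKDTree.query_ball_point` with one search radius per point; its `workers` argument (SciPy 1.6+) can be set to -1 to use all CPU cores. Whether this is faster depends on the radius distribution, and the per-point neighbour lists can take a lot of memory for dense data, so it is worth checking on a subset first.

```python
import numpy as np
from scipy.spatial import cKDTree


def touching_pairs_per_point_bound(points, radii, workers=1):
    """Hypothetical helper: (M, 2) array of touching pairs (i, j), i < j.

    Uses a per-point search radius r_i + max(r) instead of 2 * max(r).
    """
    tree = cKDTree(points)
    # Sphere j can only touch sphere i if d_ij <= r_i + r_j <= r_i + max(r).
    search_radii = radii + radii.max()
    # One neighbour list per point; workers=-1 would use all CPU cores.
    neighbors = tree.query_ball_point(points, r=search_radii, workers=workers)
    # Flatten the lists into parallel index arrays i (query point) and j (neighbour).
    counts = np.array([len(nb) for nb in neighbors], dtype=np.intp)
    i = np.repeat(np.arange(len(points)), counts)
    j = np.concatenate([np.asarray(nb, dtype=np.intp) for nb in neighbors])
    # Keep each unordered pair once; this also drops the i == j self-matches.
    keep = i < j
    i, j = i[keep], j[keep]
    # Exact test on the remaining candidates: centre distance <= summed radii.
    distances = np.linalg.norm(points[i] - points[j], axis=1)
    touching = distances <= radii[i] + radii[j]
    return np.column_stack((i[touching], j[touching]))


# Three made-up spheres on the x axis.
points = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [6.5, 0.0, 0.0]])
radii = np.array([2.0, 2.0, 0.5])
print(touching_pairs_per_point_bound(points, radii))  # → [[0 1]]
```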
    
