Tags: python, algorithm, scipy, networkx, minimum-spanning-tree

Efficiently Find Approximate Minimum Spanning Tree of a Large Graph


I have a large number of nodes, ~25000, each with a 3D position.

I want to produce a sparsely connected graph, with edges given by the distance between nodes, for use in a GNN.

Algorithms to find the Minimum Spanning Tree (MST) typically start from the fully connected graph and then remove edges until only the MST remains. This is very slow for large graphs.

To try to speed this up, I tried using scipy's sparse_distance_matrix to limit the radius of initial connections to the maximum nearest neighbor distance, but this results in multiple connected components for some graphs.

(Example of this method failing on a smaller graph:)

[Image: Example for a small graph]

Here's what I've tried:

import networkx as nx  # needed for nx.from_scipy_sparse_array below
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial import KDTree
from scipy.spatial import distance_matrix

# Create a random set of 25000 node positions
positions = np.random.normal(loc=500000, scale=10000, size=(25000, 3))

# Method 1: Full MST Search (Too Slow)
dist_matrix = distance_matrix(positions, positions)
mst = minimum_spanning_tree(dist_matrix)


# Method 2: Start from sparse matrix (Results in multiple connected components)
kdtree = KDTree(positions)

# Find the maximum nearest neighbour distance for the graph
distances, _ = kdtree.query(positions, k=2)
max_neighbour_distance = np.max(distances[:, 1])
max_neighbour_distance = np.ceil(max_neighbour_distance) # Round up to avoid floating point errors in MST search


# Threshold the distance matrix by this distance
sparse_dist_matrix = kdtree.sparse_distance_matrix(kdtree, max_neighbour_distance, output_type="coo_matrix")

mst = minimum_spanning_tree(sparse_dist_matrix)

G = nx.from_scipy_sparse_array(mst)

I don't need the true minimum spanning tree; I just need the graph to be connected with as few edges as possible, to speed up GNN performance. Even the sparse method is too slow for some graphs.

I considered a method based on https://www.cs.umd.edu/~mount/Papers/soda16-emst.pdf but it looks hard to implement, e.g. scipy doesn't have quadtrees.

Converting the full distance matrix to a networkx graph and then using its implementation of Boruvka's algorithm is even slower, since it's not intended for large graphs. Adding a multiplier to max_neighbour_distance would help ensure there's only one connected component, but it would also increase processing time and won't always be enough.
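One way to see whether a given radius (or multiplier) is large enough is to count the connected components of the thresholded graph before running the MST. The following is only a minimal sketch: count_components is a hypothetical helper name, and it reuses the same KDTree.sparse_distance_matrix call as Method 2.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components
from scipy.spatial import KDTree


def count_components(positions, radius):
    """Count connected components when only pairs within `radius` are linked.

    Hypothetical helper for illustration, not part of scipy.
    """
    tree = KDTree(positions)
    # Sparse matrix holding only the pairwise distances <= radius
    graph = tree.sparse_distance_matrix(tree, radius, output_type="coo_matrix")
    # Treat the graph as undirected; isolated nodes count as their own component
    n_components, _ = connected_components(graph, directed=False)
    return n_components


# Two well-separated pairs of points: a small radius leaves them disconnected
points = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                   [10.0, 0.0, 0.0], [11.0, 0.0, 0.0]])
print(count_components(points, 1.5), count_components(points, 10.0))
```

Any result greater than 1 means the MST of that thresholded graph will actually be a spanning forest.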


Solution

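  • The Euclidean minimum spanning tree of a set of points is a subgraph of their Delaunay triangulation. In 2D the triangulation has a linear number of edges; in 3D the worst case is quadratic, but for well-spread points it is usually close to linear in practice.

    SciPy can compute the Delaunay triangulation efficiently (scipy.spatial.Delaunay).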
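    A minimal sketch of that idea, assuming the points are in general position (no duplicates): triangulate, collect the unique edges of every simplex, weight them by Euclidean length, and run minimum_spanning_tree on that sparse graph only. The helper name delaunay_mst and the 2000-point demo size are illustrative choices, not part of scipy or the question.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components, minimum_spanning_tree
from scipy.spatial import Delaunay


def delaunay_mst(positions):
    """Euclidean MST of `positions`, computed on the Delaunay edges only.

    Hypothetical helper for illustration. Returns a sparse matrix whose
    nonzero entries are the MST edge lengths.
    """
    n = len(positions)
    tri = Delaunay(positions)
    # Each simplex (a tetrahedron in 3D) has k vertices; take every vertex pair.
    k = tri.simplices.shape[1]
    edges = np.vstack([tri.simplices[:, [i, j]]
                       for i in range(k) for j in range(i + 1, k)])
    # Sort each pair so (a, b) and (b, a) match, then drop duplicate edges.
    edges = np.unique(np.sort(edges, axis=1), axis=0)
    weights = np.linalg.norm(positions[edges[:, 0]] - positions[edges[:, 1]], axis=1)
    graph = coo_matrix((weights, (edges[:, 0], edges[:, 1])), shape=(n, n))
    return minimum_spanning_tree(graph)


# Smaller demo than the 25000 nodes in the question, same distribution
rng = np.random.default_rng(0)
positions = rng.normal(loc=500000, scale=10000, size=(2000, 3))
mst = delaunay_mst(positions)
n_components, _ = connected_components(mst, directed=False)
print(mst.nnz, n_components)  # edge count and number of components
```

    The result can be passed to nx.from_scipy_sparse_array(mst) as in the question. Because the Delaunay triangulation of a point set is connected, this approach does not need a distance threshold at all.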