Tags: numpy, networkx, geospatial, minimum-spanning-tree

How to extract graph edges in a specific order from networkx?


I am trying to find the shortest path that passes through a set of roughly aligned points. It has to work for any orientation of the line, so I can't simply sort the points by their x or y values.
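If the points really do lie close to one straight line, one orientation-independent option (not the MST approach used in this question, and only a sketch) is to project them onto their principal axis and sort by that projection. The helper name `order_along_principal_axis` is hypothetical; it relies only on NumPy's SVD. It assumes a near-straight line and will scramble curved or back-tracking point sets.

```python
import numpy as np

def order_along_principal_axis(points):
    """Sort 2-D points by their projection onto the direction of greatest spread.

    Assumes the points lie close to one straight line; curved or
    back-tracking point sets need a graph-based ordering instead.
    """
    pts = np.asarray(points, dtype=float)
    centered = pts - pts.mean(axis=0)
    # The first right-singular vector of the centered data is the principal axis.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    projection = centered @ vt[0]
    # SVD fixes the axis only up to sign, so the order may come out reversed.
    return pts[np.argsort(projection, kind="stable")]

# A nearly vertical line whose x values jitter: sorting by x would scramble it.
pts = np.array([[0.1, 2.0], [0.0, 0.0], [-0.1, 3.0], [0.05, 1.0]])
print(order_along_principal_axis(pts))
```

Either direction along the line is a valid result, which is usually fine for building a LineString.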

My approach is to build a complete graph with networkx, compute its minimum spanning tree, and export the resulting path. The end result should be a single shapely LineString (not a MultiLineString).

The networkx plot shows exactly what I want, but I can't find a way to export the edges in the right order to build a LineString: the exported points always come out jumbled. To build the LineString I need the vertices (nodes) in path order, as they appear along the plotted tree.
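When the tree is a simple path (no branches), its nodes can be put in order by walking from one of its two degree-1 nodes. A minimal sketch under that assumption; `path_order` is a hypothetical helper name, while `Graph.degree()` and `nx.dfs_preorder_nodes` are standard networkx calls. Isolated nodes (degree 0) are ignored because the walk never reaches them.

```python
import networkx as nx

def path_order(tree):
    """Return the nodes of a path-shaped tree from one end to the other.

    Assumes the tree is a simple path; a branching tree has more than
    two leaves and therefore no single end-to-end ordering.
    """
    ends = [n for n, d in tree.degree() if d == 1]
    if len(ends) != 2:
        raise ValueError(f"expected a path with 2 ends, found {len(ends)}")
    # On a path, a depth-first walk from one end visits the nodes in sequence.
    return list(nx.dfs_preorder_nodes(tree, source=ends[0]))

# Path 0-1-2-3 whose edges were added out of order.
T = nx.Graph([(2, 3), (0, 1), (1, 2)])
print(path_order(T))  # [3, 2, 1, 0]
```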

Can anyone help me? Here's the example code:

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform

points_line = np.array([[  0.46734317, 149.36430674],
       [  0.46734547, 149.36334419],
       [  0.46744031, 149.36238631],
       [  0.4676268 , 149.36144199],
       [  0.46734317, 149.36430674],
       [  0.46743343, 149.36526506],
       [  0.4676154 , 149.36621026],
       [  0.46788741, 149.36713358],
       [  0.46824692, 149.36802648],
       [  0.94443392, 150.40378967],
       [  0.4676268 , 149.36144199],
       [  1.55364459, 144.98283937],
       [  0.94443392, 150.40378967],
       [  0.68606383, 157.76059211],
       [  0.68606634, 157.76135963],
       [  1.55364459, 144.98283937],
       [  1.6347943 , 136.57997287],
       [  0.92188273, 132.24795534],
       [  0.92178416, 132.24715728],
       [  0.92175003, 132.24635387],
       [  0.90765426, 125.94462804],
       [  0.90769726, 125.94367903],
       [  0.68606634, 157.76135963],
       [  1.08596441, 167.35367069],
       [  0.66299718, 175.68436124],
       [  0.90769726, 125.94367903],
       [  1.8455184 , 115.86662374],
       [  1.22148527, 103.42945831],
       [  1.22148224, 103.42852062],
       [  1.22156706, 103.42758676],
       [  1.22173897, 103.42666495],
       [  1.89690364, 100.55965775],
       [  1.23628246,  92.47574962],
       [  1.23624942,  92.47487388],
       [  1.23629318,  92.47399861],
       [  1.23641341,  92.47313053],
       [  2.28757468,  86.74385772],
       [  2.28778124,  86.74296467],
       [  2.2880687 ,  86.74209429],
       [  2.28843466,  86.74125389],
       [  2.28887603,  86.74045053],
       [  2.2893891 ,  86.73969096],
       [  2.82731279,  86.01709145]])

# Create a graph
G = nx.Graph()

# Add nodes to the graph
for i, point in enumerate(points_line):
    G.add_node(i, pos=point)

# Calculate pairwise distances between all points
distances = pdist(points_line)
distance_matrix = squareform(distances)

# Add edges to the graph with weights based on distances
for i in range(len(points_line)):
    for j in range(i+1, len(points_line)):
        G.add_edge(i, j, weight=distance_matrix[i][j])

# Calculate the minimum spanning tree
mst = nx.minimum_spanning_tree(G)

# Extract edges from the minimum spanning tree
edges = list(mst.edges())

# Extract consecutive points from the edges to form the sequence of points
consecutive_points = []
for edge in edges:
    consecutive_points.append(points_line[edge[0]])
    consecutive_points.append(points_line[edge[1]])

# Convert the consecutive points to a numpy array
consecutive_points_array = np.array(consecutive_points)

# Find indices where the next point is not the same as the current point
indices = np.concatenate(([True], np.any(np.diff(consecutive_points_array, axis=0) != 0, axis=1)))

# Filter out the points based on the indices
filtered_points_array = consecutive_points_array[indices]

# Plot the original points and the minimum spanning tree
pos = nx.get_node_attributes(G, 'pos')

plt.figure(figsize=(12, 8))
nx.draw_networkx_nodes(G, pos, node_size=20)
nx.draw_networkx_edges(mst, pos, edge_color='r', width=1)
plt.title("Minimum Spanning Tree")
plt.axis('equal')
plt.show()

I tried networkx's built-in export functions to get a NumPy array, but the result contained repeated nodes and was still out of order after cleaning. Suggestions from LLMs didn't work either.
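The repeats come from how networkx orders the edge list: `Graph.edges()` reports edges node by node, following node and adjacency insertion order rather than the path, so flattening it revisits shared nodes and jumps between parts of the tree. A toy path 0-1-2-3 (not the question's data) shows the effect; note that a filter on consecutive duplicates, like the `np.diff` step in the question, would not remove the `2, 3, 2` pattern.

```python
import networkx as nx

# A path 0-1-2-3 whose edges were inserted out of order.
T = nx.Graph()
T.add_edges_from([(2, 3), (0, 1), (1, 2)])

edges = list(T.edges())
print(edges)  # [(2, 3), (2, 1), (0, 1)]

# Flattening the edge list, as the question's loop does, repeats shared
# nodes and jumps between distant parts of the path.
flat = [n for edge in edges for n in edge]
print(flat)   # [2, 3, 2, 1, 0, 1]
```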


Solution

  • A simple option is to call remove_edges_from on the tree to drop the zero-weight edges that link duplicate points (the array contains exact repeats, e.g. indices 0 and 4), then take the two end nodes of the remaining path and index the points_line array with the node list returned by all_simple_paths between them:

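    # shapely >= 2.0; on shapely 1.x use: from shapely.geometry import LineString
    from shapely import LineString
    
    # Duplicate points are joined by zero-weight edges; for this data, removing
    # them leaves each duplicate as an isolated node (degree 0)
    mst.remove_edges_from(
        [(u, v) for (u, v, w) in mst.edges(data=True) if w["weight"] == 0]
    )
    
    # The path's end points are its leaves (degree 1); there must be exactly two, e.g. 24 and 42
    extremities = [n for n, d in mst.degree() if d == 1]
    
    # all_simple_paths yields a single path here; list(...) wraps it as [[...]],
    # so indexing gives shape (1, n, 2) and squeeze() drops the leading axis
    ls = LineString(points_line[list(nx.all_simple_paths(mst, *extremities))].squeeze())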

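A variation on the same idea, sketched under stated assumptions: merge exact duplicates with `np.unique` before building the graph, so no zero-weight edges reach the tree, and, because a tree has exactly one path between any two nodes, use `nx.shortest_path` between the two leaves instead of `all_simple_paths`. The helper name `ordered_polyline` is hypothetical, pairwise distances are computed with NumPy broadcasting instead of scipy's `pdist`, and only exact duplicates are merged. It also assumes, as the answer's "must be exactly two" comment implies, that the deduplicated tree is a simple path.

```python
import numpy as np
import networkx as nx

def ordered_polyline(points):
    """Order roughly aligned 2-D points along their minimum spanning tree.

    Sketch under two assumptions: exact duplicate points can be merged, and
    the MST of the remaining points is a simple path (no branches).
    """
    arr = np.asarray(points, dtype=float)
    # Merge exact duplicates; np.unique sorts rows, so restore first-occurrence order.
    _, first = np.unique(arr, axis=0, return_index=True)
    pts = arr[np.sort(first)]
    n = len(pts)
    if n < 2:
        raise ValueError("need at least two distinct points")

    # Complete graph weighted by Euclidean distance.
    dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
    G = nx.Graph()
    G.add_nodes_from(range(n))
    G.add_weighted_edges_from(
        (i, j, dist[i, j]) for i in range(n) for j in range(i + 1, n)
    )
    mst = nx.minimum_spanning_tree(G)

    # A tree with exactly two leaves is a simple path.
    ends = [v for v, d in mst.degree() if d == 1]
    if len(ends) != 2:
        raise ValueError("MST is not a simple path; no single ordering exists")
    # The unique path between the two leaves visits every node in order.
    order = nx.shortest_path(mst, ends[0], ends[1])
    return pts[order]

demo = ordered_polyline([[2.0, 0.0], [0.0, 0.0], [1.0, 0.1], [3.0, 0.1], [1.0, 0.1]])
print(demo)
```

The returned array can be passed directly to shapely's `LineString(...)`. Checking the leaf count up front turns a branching tree into an explicit error instead of a silently wrong line.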