Is there a way to split the output of a scipy.sparse.csgraph.minimum_spanning_tree operation by dropping the edge with the greatest weight in the tree? I am trying to get access to each of the subtrees that would result from dropping that edge, assuming it is not an outer edge of the minimum spanning tree.
Using the SciPy docs example:
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
X = csr_matrix([[0, 8, 0, 3],
                [0, 0, 2, 5],
                [0, 0, 0, 6],
                [0, 0, 0, 0]])
Tcsr = minimum_spanning_tree(X)
# print(Tcsr)
# (0,3) 3.0
# (3,1) 5.0
# (1,2) 2.0
What is the best way to drop the middle edge in the minimum spanning tree above (the (3,1) edge, which also has the largest weight, 5.0) and get access to the other two edges separately? I am trying to do this on large graphs, so I want to avoid large Python loops where possible. Thanks.
I had the same problem and managed to figure out a solution using only scipy. All this does is take the MST, locate the edge with the maximum weight, delete it (i.e. zero it out), and then use the connected_components function to figure out which nodes remain connected.
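As a minimal sketch, the same three steps can also be run on the SciPy docs example from the question while keeping everything sparse, so a large MST is never converted to a dense array. It assumes the heaviest edge is unique (`argmax` simply picks the first stored maximum on ties). Instead of writing a zero into the sparse matrix, it rebuilds the tree without that edge, to be safe, since explicitly stored zeros in sparse input may be interpreted as edges by csgraph routines.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components

X = csr_matrix([[0, 8, 0, 3],
                [0, 0, 2, 5],
                [0, 0, 0, 6],
                [0, 0, 0, 0]])

# COO format exposes the tree's edges as parallel row/col/data arrays.
mst = minimum_spanning_tree(X).tocoo()

# Index of the heaviest edge among the stored (nonzero) entries.
k = mst.data.argmax()
print(mst.row[k], mst.col[k], mst.data[k])  # 1 3 5.0

# Rebuild the tree from every edge except the heaviest one (no explicit zeros).
keep = np.arange(mst.nnz) != k
pruned = csr_matrix((mst.data[keep], (mst.row[keep], mst.col[keep])),
                    shape=mst.shape)

# Each remaining subtree gets its own component label.
n, labels = connected_components(pruned, directed=False)
print(n, labels)  # 2 [0 1 1 0]
```

Nodes 0 and 3 end up in one subtree and nodes 1 and 2 in the other, i.e. the (0,3) and (1,2) edges are now separate.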
Here is the complete script with comments:
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components
from scipy.sparse import csr_matrix
# Create a random "distance" matrix.
# Keep only the strict upper triangle (k=1 also drops the diagonal), since a
# distance matrix is symmetric and each edge only needs to appear once.
a = np.random.rand(5, 5)
a = np.triu(a, k=1)
# Create the minimum spanning tree.
mst = minimum_spanning_tree(csr_matrix(a))
mst = mst.toarray()
# Get the index of the maximum value.
# `argmax` returns the index of the _flattened_ array;
# `unravel_index` converts it back.
idx = np.unravel_index(mst.argmax(), mst.shape)
# Clear out the maximum value to split the tree.
mst[idx] = 0
# Label connected components.
num_graphs, labels = connected_components(mst, directed=False)
# We should have two trees.
assert num_graphs == 2
# Use indices as node ids and group them according to their graph.
results = [[] for _ in range(num_graphs)]
for node, label in enumerate(labels):
    results[label].append(node)
print(results)
This will yield something like:
[[0, 1, 4], [2, 3]]
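The loop that builds `results` runs once per node, which the question hoped to avoid on large graphs. A minimal sketch of a vectorized alternative follows; it re-runs the script's pipeline with a seeded generator (`np.random.default_rng(0)` and 8 nodes are arbitrary choices for illustration). A stable `argsort` of `labels` lists node ids label by label, `np.bincount` counts the nodes per label, and `np.split` cuts the sorted ids into one array per subtree. It also flags the case the question mentions: if a subtree has a single node, the dropped edge was an outer (leaf) edge of the MST.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree, connected_components

# Same pipeline as the script above, with a seeded generator for repeatability.
rng = np.random.default_rng(0)
a = np.triu(rng.random((8, 8)), k=1)
mst = minimum_spanning_tree(csr_matrix(a)).toarray()
mst[np.unravel_index(mst.argmax(), mst.shape)] = 0
num_graphs, labels = connected_components(mst, directed=False)

# Vectorized grouping: stable argsort orders node ids by label (ascending ids
# within each label), and bincount gives the size of each label's group.
order = np.argsort(labels, kind="stable")
counts = np.bincount(labels, minlength=num_graphs)
groups = np.split(order, np.cumsum(counts)[:-1])
print([g.tolist() for g in groups])

# A single-node subtree means the dropped edge was an outer (leaf) edge.
print("outer edge dropped:", bool((counts == 1).any()))
```

Sorting is O(n log n) in NumPy rather than a Python-level loop over every node, which is the part that tends to dominate on large graphs.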