Suppose I am given a directed graph in Python:
V = [1, 2, 3, 4, 5]
E = {
    1: [2, 3, 4],
    2: [1, 2, 3],
    3: [1, 4, 5],
    4: [5],
    5: [1, 3],
}
c = [81, 23, 43, 22, 100]
V and E represent the vertex and edge sets of the graph as a list and a dictionary, respectively, and c is a cost function on the vertex set, i.e. c(1) = 81, c(2) = 23, etc. Visualizing the graph (V, E) in two dimensions is easy with the networkx package, but I additionally want to place the vertices at varying heights along the z axis (instead of only on the xy plane), so that the height of each vertex equals its cost.
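Since c is listed in the same order as V, one way to treat it as an actual function c(v), rather than relying on list positions, is to zip the two into a dict. A minimal sketch; the name `cost` is just an illustrative choice:

```python
# Vertex set, adjacency lists and costs as given in the question
V = [1, 2, 3, 4, 5]
E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
c = [81, 23, 43, 22, 100]

# Pair each vertex with its cost so that cost[v] plays the role of c(v)
cost = dict(zip(V, c))

print(cost[1], cost[2])  # 81 23
```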
How can I do so?
Here's something to get you started.
As you point out, networkx.draw will plot a 2-D graph. The point locations can be extracted from the networkx.spring_layout function (which is the default layout used when draw is called).
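import networkx
import numpy as np
import itertools as it
import matplotlib.pyplot as plt
# V, E and c are the list/dict/list from the question
g = networkx.Graph(E)
# dict mapping each vertex to an (x, y) position; randomized on each call
pts = networkx.spring_layout(g)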
Example output (the positions will differ between runs because the networkx layout functions are randomized):
{1: array([-0.35707887, -0.02227005]),
2: array([-0.35348547, -1. ]),
3: array([ 0.26374451, -0.24148893]),
4: array([0.56832209, 0.50997504]),
5: array([-0.12150225, 0.75378394])}
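If you want the same picture every time, spring_layout accepts a `seed` argument. A small sketch showing that two calls with the same seed give identical positions; the value 42 is arbitrary:

```python
import networkx
import numpy as np

E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
g = networkx.Graph(E)

# A fixed seed makes spring_layout deterministic across runs
pts_a = networkx.spring_layout(g, seed=42)
pts_b = networkx.spring_layout(g, seed=42)

same = all(np.allclose(pts_a[v], pts_b[v]) for v in g)
print(same)  # True
```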
Extract the points as follows:
pts_array = np.array([pt[1] for pt in sorted(list(pts.items()))])
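Because the rows of the array have to line up with the entries of `c`, an alternative that avoids sorting is to look each vertex up in the order given by `V`. This sketch assumes, as in the question, that `c[k]` is the cost of `V[k]`:

```python
import networkx
import numpy as np

V = [1, 2, 3, 4, 5]
E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
pts = networkx.spring_layout(networkx.Graph(E), seed=1)

# Row k holds the (x, y) position of vertex V[k], matching the order of c
pts_array = np.array([pts[v] for v in V])
print(pts_array.shape)  # (5, 2)
```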
The connections between vertices can be obtained from your E dict like this:
# one (source, target) pair per directed edge
ix_connections = [(u, v) for u, nbrs in E.items() for v in nbrs]
[(1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (2, 3), (3, 1), (3, 4), (3, 5), (4, 5), (5, 1), (5, 3)]
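networkx can also produce the same pairs for you if the graph is built as a `DiGraph`, which keeps edge direction; `selfloop_edges` picks out loops such as the one on vertex 2. A sketch:

```python
import networkx

E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}

# A DiGraph keeps direction; edges() yields (source, target) pairs
dg = networkx.DiGraph(E)
ix_connections = list(dg.edges())
print(ix_connections)
# [(1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (2, 3), (3, 1), (3, 4), (3, 5), (4, 5), (5, 1), (5, 3)]

# Self-loops (edges from a vertex to itself)
loops = list(networkx.selfloop_edges(dg))
print(loops)  # [(2, 2)]
```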
Thus, we can manually plot the graph like this:
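plt.figure()
ax = plt.gca()
# one marker per vertex (row k of pts_array is vertex k + 1)
ax.plot(*pts_array.T, 'o')
# one black segment per edge; vertex labels are 1-based, array rows are 0-based
for i, j in ix_connections:
    ax.plot(*pts_array[[i - 1, j - 1], :].T, 'k')
plt.show()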
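For the 2-D case, networkx can draw the directed edges with arrowheads itself if you pass it a `DiGraph` and an explicit position dict (so the layout stays under your control). A sketch; the seed is arbitrary:

```python
import matplotlib.pyplot as plt
import networkx

E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
dg = networkx.DiGraph(E)
pos = networkx.spring_layout(dg, seed=42)

fig, ax = plt.subplots()
# For a DiGraph with arrows=True, edges are drawn as arrows; labels are on by default
networkx.draw_networkx(dg, pos=pos, ax=ax, arrows=True)
plt.show()
```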
Now we can plot it with heights by using a 3-D projection axis and taking each vertex's cost as its z coordinate:
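# append the costs as a third (z) column
pts_array_z = np.hstack((np.array([pt[1] for pt in sorted(list(pts.items()))]),
                         np.array(c)[:, np.newaxis]))
fig = plt.figure()
# projection='3d' works out of the box on matplotlib >= 3.2;
# older versions need `from mpl_toolkits.mplot3d import Axes3D` first
ax = fig.add_subplot(projection='3d')
ax.plot(*pts_array_z.T, 'o')
for i, j in ix_connections:
    ax.plot(*pts_array_z[[i - 1, j - 1], :].T, 'k')
plt.show()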
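The `i - 1` indexing above assumes the vertex labels are exactly 1..n. A sketch of a label-agnostic variant, assuming rows should follow the order of `V` (and hence of `c`); `idx` and `segments` are illustrative names:

```python
import matplotlib.pyplot as plt
import networkx
import numpy as np

V = [1, 2, 3, 4, 5]
E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
c = [81, 23, 43, 22, 100]
pts = networkx.spring_layout(networkx.Graph(E), seed=7)

# Map each vertex label to its row, so labels need not be 1..n
idx = {v: k for k, v in enumerate(V)}

# Rows follow V: x, y from the layout and z = cost
pts_array_z = np.column_stack((np.array([pts[v] for v in V]), c))

# One segment per directed edge: shape (number of edges, 2 endpoints, 3 coords)
segments = np.array([pts_array_z[[idx[u], idx[v]]]
                     for u, nbrs in E.items() for v in nbrs])
print(segments.shape)  # (12, 2, 3)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.plot(*pts_array_z.T, 'o')
for seg in segments:
    ax.plot(*seg.T, 'k')  # seg.T is (x pair, y pair, z pair)
plt.show()
```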
NB: These figures don't include the self-loop on vertex 2, which you'll need to draw yourself. Also, the graphs drawn here are undirected; arrowheads will need to be added for your directed edges.
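One way to address both points in 3-D is matplotlib's `Axes3D.quiver`, which draws an arrow from a tail point along a direction vector, plus a small circle for the self-loop. This is a sketch under some assumptions: the layout is stretched to roughly the cost range because quiver builds arrowheads in data units (they can look distorted if x/y and z scales differ wildly), and the loop radius `r` is an arbitrary choice:

```python
import matplotlib.pyplot as plt
import networkx
import numpy as np

V = [1, 2, 3, 4, 5]
E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
c = [81, 23, 43, 22, 100]

# Assumption: scale x/y to about the cost range so arrowheads are not squashed
span = max(c) - min(c)
pts = networkx.spring_layout(networkx.DiGraph(E), seed=42, scale=span / 2)
xyz = {v: np.array([pts[v][0], pts[v][1], cost]) for v, cost in zip(V, c)}

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
P = np.array([xyz[v] for v in V])
ax.scatter(*P.T)
for v in V:
    ax.text(*xyz[v], str(v))  # label each vertex

r = 0.05 * span  # hypothetical self-loop radius
t = np.linspace(0, 2 * np.pi, 50)
for u, nbrs in E.items():
    for v in nbrs:
        x0, y0, z0 = xyz[u]
        if u == v:
            # Self-loop: a small horizontal circle that passes through the vertex
            ax.plot(x0 + r + r * np.cos(t), y0 + r * np.sin(t), np.full_like(t, z0), 'k')
        else:
            # Arrow with its tail at u and its tip at v (length=1, no normalization)
            dx, dy, dz = xyz[v] - xyz[u]
            ax.quiver(x0, y0, z0, dx, dy, dz, color='k', arrow_length_ratio=0.1)
plt.show()
```

`arrow_length_ratio` controls the head size relative to each arrow's length; smaller values keep the heads of long edges from dominating the plot.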