
Drawing a graph network in 3D


Suppose I am given a directed graph in Python:

V = [1, 2, 3, 4, 5]
E = {
    1: [2, 3, 4],
    2: [1, 2, 3],
    3: [1, 4, 5],
    4: [5],
    5: [1, 3],
}
c = [81, 23, 43, 22, 100]

V is the vertex set (a list) and E is the edge set, stored as a dictionary that maps each vertex to the list of vertices it points to. c is a cost function on the vertex set, i.e. c(1) = 81, c(2) = 23, and so on. The graph (V, E) is easy to visualize in two dimensions with the networkx package, but I additionally want to place the vertices at varying heights along the z axis (instead of only in the xy plane), so that the height of each vertex equals its cost.
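For reference, this data can be loaded into networkx directly. A minimal sketch, assuming networkx is installed: `networkx.DiGraph` accepts a dict of adjacency lists, and here the costs are attached as a node attribute under the name `"cost"`, which is an arbitrary choice for illustration.

```python
import networkx as nx

V = [1, 2, 3, 4, 5]
E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
c = [81, 23, 43, 22, 100]

# DiGraph keeps edge direction; a dict of lists is read as adjacency lists
G = nx.DiGraph(E)
# Attach each vertex's cost as a node attribute ("cost" is an arbitrary name)
nx.set_node_attributes(G, dict(zip(V, c)), "cost")

print(G.number_of_edges())  # 12, including the self-loop (2, 2)
print(G.nodes[5]["cost"])   # 100
```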

How can I do so?


Solution

  • Here's something to get you started.

    As you point out, networkx.draw will plot a 2-D graph.


    The point locations can be obtained from the networkx.spring_layout function (the default layout used when draw is called), which returns a dict mapping each vertex to its (x, y) coordinates.

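    import networkx
    import numpy as np
    import itertools as it
    import matplotlib.pyplot as plt
    
    # Build an undirected graph from the adjacency dict E;
    # edge direction does not matter for computing a layout
    g = networkx.Graph(E)
    
    # dict mapping each vertex to an array of (x, y) coordinates
    pts = networkx.spring_layout(g)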
    

    Example output (the coordinates will differ from run to run, because spring_layout starts from random initial positions):

    {1: array([-0.35707887, -0.02227005]),
     2: array([-0.35348547, -1.        ]),
     3: array([ 0.26374451, -0.24148893]),
     4: array([0.56832209, 0.50997504]),
     5: array([-0.12150225,  0.75378394])}
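If you'd rather get the same layout on every run, spring_layout also accepts a seed argument (present in recent networkx releases). A minimal sketch; the value 42 is arbitrary:

```python
import networkx as nx
import numpy as np

E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
g = nx.Graph(E)

# Fixing the seed fixes the random initial positions, so the layout is repeatable
pts_a = nx.spring_layout(g, seed=42)
pts_b = nx.spring_layout(g, seed=42)

same = all(np.allclose(pts_a[v], pts_b[v]) for v in g)
print(same)  # True
```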
    

    Extract the points as follows:

    # one row per vertex, in sorted vertex order (row i-1 holds vertex i here)
    pts_array = np.array([pts[v] for v in sorted(pts)])
    

    and the connections between vertices can be obtained from your E dict like this:

    ix_connections = list(it.chain.from_iterable(zip(it.repeat(u), vs) for u, vs in E.items()))
    
    [(1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (2, 3), (3, 1), (3, 4), (3, 5), (4, 5), (5, 1), (5, 3)]
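The `i-1` indexing used in the plotting code works only because the vertices are labelled 1 to n in order. A sketch of a label-independent alternative uses an explicit vertex-to-row map. It assumes `pts` is the dict returned by `spring_layout` (hard-coded here, rounded from the example output above, so the snippet runs on its own), and `build_segments` is a hypothetical helper name:

```python
import numpy as np

V = [1, 2, 3, 4, 5]
E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
c = [81, 23, 43, 22, 100]
# Stand-in for the dict returned by networkx.spring_layout
pts = {1: np.array([-0.357, -0.022]), 2: np.array([-0.353, -1.0]),
       3: np.array([0.264, -0.241]), 4: np.array([0.568, 0.510]),
       5: np.array([-0.122, 0.754])}

def build_segments(V, E, pts, cost):
    """Return an (n, 3) array of vertex positions and a list of (2, 3) edge segments.

    Rows follow the order of V, so vertex labels need not be 1..n.
    """
    row = {v: k for k, v in enumerate(V)}  # vertex label -> row index
    # x, y from the layout, z from the cost of each vertex
    xyz = np.column_stack((np.array([pts[v] for v in V]), cost))
    segments = [xyz[[row[u], row[v]]] for u, targets in E.items() for v in targets]
    return xyz, segments

xyz, segments = build_segments(V, E, pts, c)
print(xyz.shape, len(segments))  # (5, 3) 12
```

Each segment can then be passed to `ax.plot(*seg.T, 'k')` in the same way as the loops that follow.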
    

    Thus, we can manually plot the graph like this:

    fig, ax = plt.subplots()
    # one marker per vertex
    ax.plot(*pts_array.T, 'o')
    # one black line per edge; vertex i is stored in row i-1
    for i, j in ix_connections:
        ax.plot(*pts_array[[i - 1, j - 1], :].T, 'k')
    plt.show()
    

    [Figure: 2-D plot of the graph]
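For the 2-D case alone, networkx can draw the arrowheads itself if the graph is built as a DiGraph. A minimal sketch, assuming networkx and matplotlib are installed; the Agg backend line only lets it run without a display, and seed=1 is arbitrary:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; drop this line and call plt.show() for a window
import matplotlib.pyplot as plt
import networkx as nx

E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}

G = nx.DiGraph(E)                  # keeps edge direction
pos = nx.spring_layout(G, seed=1)  # seed only makes the layout repeatable
fig, ax = plt.subplots()
# On a DiGraph, arrows=True draws each edge with an arrowhead at its target
nx.draw(G, pos, ax=ax, with_labels=True, arrows=True)
fig.canvas.draw()
```

This does not help with the third dimension, which is why the manual approach is needed for the 3-D plot.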

    Now we can add the heights: append each vertex's cost as a z coordinate and plot on a 3-D projection axis with ax.plot:

    # append each vertex's cost as a third (z) column
    pts_array_z = np.column_stack((pts_array, c))
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.plot(*pts_array_z.T, 'o')
    for i, j in ix_connections:
        ax.plot(*pts_array_z[[i - 1, j - 1], :].T, 'k')
    plt.show()
    

    [Figure: 3-D plot of the graph]

    NB: These figures don't include the self-loop on vertex 2 (the edge (2, 2) is plotted as a zero-length line, so it is invisible), which you'll need to draw yourself. Also, the edges are drawn without direction; for your directed graph you will need to add arrowheads.
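One way to handle both points in 3-D is sketched below, assuming Matplotlib 3.2 or later: directed edges are drawn with Axes3D.quiver, whose arrows start at the tail vertex and, with the default length=1 and normalize=False, end exactly at the head vertex; each self-loop is drawn as a small circle touching its vertex. The layout coordinates are hard-coded (rounded from the example output above) so the snippet runs on its own, and the loop radius and arrow_length_ratio values are arbitrary.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; drop this line and call plt.show() for a window
import matplotlib.pyplot as plt
import numpy as np

E = {1: [2, 3, 4], 2: [1, 2, 3], 3: [1, 4, 5], 4: [5], 5: [1, 3]}
c = [81, 23, 43, 22, 100]
# Stand-in 2-D layout, one row per vertex 1..5
xy = np.array([[-0.357, -0.022], [-0.353, -1.0], [0.264, -0.241],
               [0.568, 0.510], [-0.122, 0.754]])
pts_array_z = np.column_stack((xy, c))  # z coordinate = cost

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.plot(*pts_array_z.T, "o")

edges = [(i, j) for i, targets in E.items() for j in targets]
arrows = [(i, j) for i, j in edges if i != j]  # ordinary directed edges
loops = [i for i, j in edges if i == j]        # self-loops such as (2, 2)

# Each arrow starts at the tail vertex and points along (head - tail)
tails = np.array([pts_array_z[i - 1] for i, _ in arrows])
heads = np.array([pts_array_z[j - 1] for _, j in arrows])
q = ax.quiver(*tails.T, *(heads - tails).T, color="k", arrow_length_ratio=0.1)

# Draw each self-loop as a small horizontal circle passing through its vertex
theta = np.linspace(0, 2 * np.pi, 50)
r = 0.08  # loop radius in layout units (arbitrary)
for i in loops:
    x, y, z = pts_array_z[i - 1]
    ax.plot(x - r + r * np.cos(theta), y + r * np.sin(theta), np.full_like(theta, z), "k")
```

The arrowheads are computed in data units, so when the costs are much larger than the layout coordinates (as here) they may look distorted; rescaling c before plotting may give a more readable figure.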