Tags: python, python-3.x, matplotlib, networkx

Square nodes in networkx with matplotlib


Before starting, I should point out that this question might look like a duplicate of this one, but the solution there simply doesn't run in Python 3 with a current version of networkx: the set won't be constructed, among other problems.

So I have a networkx graph that I draw using matplotlib. Here is the code:

class Vert:
  
    # default constructor
    def __init__(self, name, size, edges):
        self.name = name
        self.size = size
        self.edges = edges


import networkx as nx
import matplotlib.pyplot as plt

nodes = []
nodes.append(Vert('A', 1, ['B', 'C']))
nodes.append(Vert('B', 3, ['D']))
nodes.append(Vert('C', 4, ['D']))
nodes.append(Vert('D', 7, []))
nodes.append(Vert('Y', 64, []))

G = nx.DiGraph()
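for v in nodes:
    # Every node currently gets the same shape attribute, 'v' (a triangle marker).
    G.add_node(v.name, s='v')
    for e in v.edges:
        G.add_edge(v.name, e)

# Sizes follow the order of `nodes`, which matches G's node order here.
node_sizes = [V.size * 100 for V in nodes]
# Collect the distinct shapes stored on the nodes.
shapes = set(data["s"] for _, data in G.nodes(data=True))

nx.draw(G, font_weight='bold', with_labels=True, node_size=node_sizes, node_shape=shapes)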

#plt.savefig('plot.png', bbox_inches='tight')
plt.show()

I need some of the nodes to have a square or triangular shape. How do I do this?


Solution

  • The code in the old answer fails to run because add_path() is no longer a method of the graph object; it is now a module-level networkx function. I have edited the answer in the older question, but the edit won't show up until it is approved, since I don't yet have edit privileges.

    If you replace

    G.add_path([0,2,5])
    G.add_path([1,4,3,0])
    G.add_path([2,4,0,5])
    

    with

    nx.add_path(G, [0,2,5])
    nx.add_path(G, [1,4,3,0])
    nx.add_path(G, [2,4,0,5])
    

    then I believe it should run successfully.


    EDIT: In response to a comment, here is an example of working code that takes into account both shape and size. It isn't particularly clean or consistent in style, but it combines the method from the previous SO question with the questioner's data-generation scheme.

    The key change is moving from nx.draw() to drawing each part of the graph separately with nx.draw_networkx_nodes(), nx.draw_networkx_edges(), and nx.draw_networkx_labels(). See the networkx drawing docs for full details. Because the node_shape argument of nx.draw_networkx_nodes() takes a single marker, this lets you draw each group of nodes that shares a shape with its own call.

    I did some rather inelegant things to tidy up the plot, including setting plt.xlim and plt.ylim and adjusting the spacing argument (k) of nx.layout.spring_layout().

    The code below gives the following plot:

    [Plot: network_viz_w_shapes_sizes]

    import networkx as nx
    import matplotlib.pyplot as plt
    
    
    class Vert:
        # Simple container for a node's name, size and outgoing edges.
        def __init__(self, name, size, edges):
            self.name = name
            self.size = size
            self.edges = edges
    
    nodes = []
    nodes.append(Vert('A', 1, ['B', 'C']))
    nodes.append(Vert('B', 3, ['D']))
    nodes.append(Vert('C', 4, ['D']))
    nodes.append(Vert('D', 7, []))
    nodes.append(Vert('Y', 64, []))
    
    G = nx.DiGraph()
    
    for v in nodes:
        # Assign the 'v' (triangle) shape to nodes whose name has an even
        # character code, and the 's' (square) shape to the rest.
        if ord(v.name) % 2 == 0:
            G.add_node(v.name, size=v.size, shape='v')
        else:
            G.add_node(v.name, size=v.size, shape='s')
        for e in v.edges:
            G.add_edge(v.name, e)
    
    # Collect the distinct shapes used by the nodes.
    shapes = set(data['shape'] for _, data in G.nodes(data=True))
    pos = nx.layout.spring_layout(G, k=2)  # Make k larger to space out nodes more.
    
    for shape in shapes:
        # Filter the subset of nodes with this shape and draw them at the
        # positions computed by the layout.
        subset = [(node, data) for node, data in G.nodes(data=True) if data['shape'] == shape]
        nodelist = [node for node, _ in subset]
        sizes = [100 * data['size'] for _, data in subset]
        nx.draw_networkx_nodes(G,
                               pos,
                               node_shape=shape,
                               nodelist=nodelist,
                               node_size=sizes)
    
    # Draw the edges between the nodes and label them
    nx.draw_networkx_edges(G,pos)
    nx.draw_networkx_labels(G, pos)
    
    plt.xlim(-2, 2) # Expand limits if large nodes spill over plot.
    plt.ylim(-2, 2) # Expand limits if large nodes spill over plot.
    
    plt.show()
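For comparison, here is a more compact sketch of the same per-shape approach; it is not the answerer's code. The helper name `draw_by_shape` is hypothetical, the `'shape'` and `'size'` node attributes follow the example above, and the edges are built with the module-level `nx.add_path()` discussed earlier. A few details are my own assumptions rather than part of the answer: the `Agg` backend and the fixed `seed` for `nx.spring_layout()` are there so it runs headless and reproducibly, and passing per-node sizes to `nx.draw_networkx_edges()` is meant to keep arrowheads from disappearing under large nodes.

```python
import matplotlib

matplotlib.use("Agg")  # Non-interactive backend so the sketch runs without a display.
import matplotlib.pyplot as plt
import networkx as nx


def draw_by_shape(G, pos, ax, size_scale=100):
    """Draw G's nodes grouped by their 'shape' attribute (hypothetical helper).

    Assumes every node has a 'shape' (matplotlib marker) and a 'size' attribute.
    Returns a dict mapping each shape to the list of nodes drawn with it.
    """
    groups = {}
    for node, data in G.nodes(data=True):
        groups.setdefault(data["shape"], []).append(node)

    for shape, nodelist in groups.items():
        # node_shape accepts one marker per call, so each group gets its own call.
        nx.draw_networkx_nodes(
            G,
            pos,
            nodelist=nodelist,
            node_shape=shape,
            node_size=[size_scale * G.nodes[n]["size"] for n in nodelist],
            ax=ax,
        )

    # Sizes in G's node order, used only to position the arrowheads.
    all_sizes = [size_scale * G.nodes[n]["size"] for n in G.nodes()]
    nx.draw_networkx_edges(G, pos, node_size=all_sizes, ax=ax)
    nx.draw_networkx_labels(G, pos, ax=ax)
    return groups


G = nx.DiGraph()
# Same data and even/odd shape assignment as above: 's' = square, 'v' = triangle.
for name, size, shape in [("A", 1, "s"), ("B", 3, "v"), ("C", 4, "s"),
                          ("D", 7, "v"), ("Y", 64, "s")]:
    G.add_node(name, size=size, shape=shape)
nx.add_path(G, ["A", "B", "D"])
nx.add_path(G, ["A", "C", "D"])

fig, ax = plt.subplots()
pos = nx.spring_layout(G, k=2, seed=1)
groups = draw_by_shape(G, pos, ax)
```

To view the result interactively, drop the `matplotlib.use("Agg")` line and call `plt.show()`, or save it with `fig.savefig(...)`. Grouping nodes into a dict in one pass avoids filtering the full node list twice for every shape.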