
Draw node shape and node color by attribute using networkx


In a graph G I have a set of nodes. Some of them have an attribute Type, which can be MASTER or DOC. Others do not have a Type defined:

>>> import networkx as nx
>>> import matplotlib.pyplot as plt
>>> G=nx.Graph()
[...]
>>> G.nodes['ART1']
{'Type': 'MASTER'}
>>> G.nodes['ZG1']
{'Type': 'DOC'}
>>> G.nodes['MG1']
{}

Afterwards I plot the graph using:

>>> nx.draw(G,with_labels = True)
>>> plt.show()

Now I get a graph with red circles. How can I get, for example, blue circles for ART (MASTER) nodes, red squares for DOC nodes, and purple circles for everything undefined in my plot?


Solution

  • There are various ways to select nodes based on their attributes. Here is how to do it with get_node_attributes and a list comprehension that picks out each subset. The drawing function draw_networkx_nodes accepts a nodelist argument, so each subset can be drawn in a separate call with its own color and shape.

    This approach is easy to extend to a broader set of conditions, and the appearance of each subset can be changed to suit your needs.

    import networkx as nx
    import matplotlib.pyplot as plt
    
    # define a graph, some nodes with a "Type" attribute, some without.
    G = nx.Graph()
    G.add_nodes_from([1, 2, 3], Type='MASTER')
    G.add_nodes_from([4, 5], Type='DOC')
    G.add_nodes_from([6])
    
    # extract nodes with a specific value of the attribute;
    # get_node_attributes only returns nodes that have 'Type' set
    types = nx.get_node_attributes(G, 'Type')
    master_nodes = [n for (n, ty) in types.items() if ty == 'MASTER']
    doc_nodes = [n for (n, ty) in types.items() if ty == 'DOC']
    # and find all the remaining nodes.
    other_nodes = list(set(G.nodes()) - set(master_nodes) - set(doc_nodes))
    
    # now draw them in subsets using the `nodelist` arg
    pos = nx.spring_layout(G)
    nx.draw_networkx_nodes(G, pos, nodelist=master_nodes,
                           node_color='blue', node_shape='o')
    nx.draw_networkx_nodes(G, pos, nodelist=doc_nodes,
                           node_color='red', node_shape='s')
    nx.draw_networkx_nodes(G, pos, nodelist=other_nodes,
                           node_color='purple', node_shape='o')
    # edges and labels are drawn once for the whole graph
    nx.draw_networkx_edges(G, pos)
    nx.draw_networkx_labels(G, pos)
    plt.show()
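To extend this to more Type values without writing one list comprehension per value, the styles can live in a lookup table. The following is a minimal sketch, assuming networkx 2.x or later: STYLES, DEFAULT_STYLE, group_by_style and draw_by_attribute are hypothetical names chosen for illustration, not part of networkx. It selects nodes with G.nodes(data='Type'), which yields None for nodes that lack the attribute, so undefined nodes fall through to the default style instead of needing a separate set difference.

```python
import networkx as nx
import matplotlib.pyplot as plt

# Hypothetical style table: Type value -> (color, marker shape).
# Nodes whose Type is missing (or not listed here) use DEFAULT_STYLE.
STYLES = {
    'MASTER': ('blue', 'o'),
    'DOC': ('red', 's'),
}
DEFAULT_STYLE = ('purple', 'o')


def group_by_style(G, attr='Type'):
    """Map each (color, shape) pair to the list of nodes drawn with it."""
    groups = {}
    # G.nodes(data=attr) yields (node, value) pairs; value is None if unset
    for n, value in G.nodes(data=attr):
        style = STYLES.get(value, DEFAULT_STYLE)
        groups.setdefault(style, []).append(n)
    return groups


def draw_by_attribute(G, pos, attr='Type', ax=None):
    """Draw one node collection per style, then edges and labels once."""
    for (color, shape), nodes in group_by_style(G, attr).items():
        nx.draw_networkx_nodes(G, pos, nodelist=nodes, node_color=color,
                               node_shape=shape, ax=ax)
    nx.draw_networkx_edges(G, pos, ax=ax)
    nx.draw_networkx_labels(G, pos, ax=ax)


if __name__ == '__main__':
    G = nx.Graph()
    G.add_nodes_from([1, 2, 3], Type='MASTER')
    G.add_nodes_from([4, 5], Type='DOC')
    G.add_nodes_from([6])
    G.add_edges_from([(1, 4), (2, 5), (3, 6)])  # hypothetical edges
    draw_by_attribute(G, nx.spring_layout(G, seed=42))
    plt.show()
```

Each (color, shape) pair still becomes its own draw_networkx_nodes call, because node_shape takes a single marker per call, whereas node_color could also be given as a per-node list. Adding a new Type value then only means adding one entry to the table.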