Search code examples
Tags: python, numpy, networkx

Issue with colorbar in NetworkX


The following code builds a graph with nx.grid_2d_graph and generates its adjacency matrix. I use an array P to color the nodes, but an error is raised when creating the colorbar. I want the colorbar to display the range of the node colors.

import numpy as np
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

# Build a 3x3 grid graph whose nodes are (row, col) tuples, then relabel the
# nodes 1..9 in the order G.nodes yields them.
G = nx.grid_2d_graph(3,3)
new_nodes = {e: n for n, e in enumerate(G.nodes, start=1)}
new_edges = [(new_nodes[e1], new_nodes[e2]) for e1, e2 in G.edges]
# Rebuild the graph from the relabelled edge list only.
# NOTE(review): a Graph keeps nodes in insertion order, and here they are
# inserted in order of first appearance in new_edges, so list(G) is not
# necessarily [1, 2, ..., 9]. The adjacency matrix rows and the mapping of
# P onto nodes below both follow list(G), not the numeric labels.
G = nx.Graph()
G.add_edges_from(new_edges)
#nx.draw(G, with_labels=True)

# Dense adjacency matrix of the relabelled graph (row/column order = list(G)).
A1 = nx.adjacency_matrix(G) 
A=A1.toarray()
print([A]) 

fig, ax = plt.subplots(1, 1)
# Values used to colour the nodes through the 'Blues' colormap.
P=np.array([1,2,5,4,5,6,7,2,10])
# nx.draw returns None, so D is None here.
D=nx.draw(G, with_labels=True, node_color=[P], node_size=1000, cmap='Blues')
# fig.colorbar needs a mappable; passing None raises the AttributeError shown
# in the traceback of this question.
fig.colorbar(D)
plt.show()

[image omitted]

The error is

in <module>
    fig.colorbar(D)

  File "C:\Users\USER\anaconda3\lib\site-packages\matplotlib\figure.py", line 1176, in colorbar
    cb = cbar.Colorbar(cax, mappable, **cb_kw)

  File "C:\Users\USER\anaconda3\lib\site-packages\matplotlib\colorbar.py", line 1171, in __init__
    if mappable.get_array() is not None:

AttributeError: 'NoneType' object has no attribute 'get_array'

Solution

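  • nx.draw returns None rather than a mappable, so there is nothing to pass to fig.colorbar. Instead, create a stand-in mappable with Matplotlib's ScalarMappable class (plt.cm.ScalarMappable), giving it the same colormap and value range as the node colors.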

    [image omitted]

    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    import networkx as nx

    # Build a 3x3 grid graph and relabel its (row, col) nodes as 1..9.
    G = nx.grid_2d_graph(3, 3)
    new_nodes = {e: n for n, e in enumerate(G.nodes, start=1)}
    new_edges = [(new_nodes[e1], new_nodes[e2]) for e1, e2 in G.edges]
    G = nx.Graph()
    # Insert the nodes in label order first. Otherwise a Graph stores them in
    # order of first appearance in new_edges (1, 4, 2, ...), and both the
    # adjacency-matrix rows and the node colours taken from P would follow
    # that order instead of 1..9.
    G.add_nodes_from(sorted(new_nodes.values()))
    G.add_edges_from(new_edges)

    # Dense adjacency matrix; rows/columns follow G's node order (1..9).
    A = nx.adjacency_matrix(G).toarray()
    print(A)

    fig, ax = plt.subplots(1, 1)
    # One value per node, in G's node order: node k is coloured by P[k - 1].
    P = np.array([1, 2, 5, 4, 5, 6, 7, 2, 10])

    # Share one colormap and one normalisation between the nodes and the
    # colorbar so the colorbar really describes the node colours.
    cmap = 'Blues'
    norm = matplotlib.colors.Normalize(vmin=np.min(P), vmax=np.max(P))

    # nx.draw returns None (no mappable), so its result is not kept.
    nx.draw(G, ax=ax, with_labels=True, node_color=P, node_size=1000,
            cmap=cmap, vmin=norm.vmin, vmax=norm.vmax)

    # -------------------------------------
    # Stand-in mappable for the colorbar. set_array([]) is only required by
    # Matplotlib < 3.1 but is harmless on newer versions.
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # sm is not attached to any Axes, so tell colorbar which Axes to take
    # space from; newer Matplotlib refuses to guess.
    fig.colorbar(sm, ax=ax)
    # -------------------------------------

    plt.show()