Tags: python, networkx, bipartite

NetworkX bipartite color mixes up order


I have created a bipartite graph using NetworkX and would like to colour the two sets separately. I use the color() function from the NetworkX bipartite module. However, the order of the nodes in the dict returned by color() differs from their order in B.nodes, for example:

B.nodes = ['a', 1, 2, 3, 4, 'c', 'b']

bipartite.color(B) = {'a': 1, 1: 0, 2: 0, 'b': 1, 4: 0, 'c': 1, 3: 0}

This results in the graph being incorrectly coloured as below:

[Image: incorrectly coloured graph]

The code is as follows:

import matplotlib.pyplot as plt
import networkx as nx
from networkx.algorithms import bipartite

B = nx.Graph()
B.add_nodes_from([1,2,3,4], bipartite=0) # Add the node attribute "bipartite"
B.add_nodes_from(['a','b','c'], bipartite=1)
B.add_edges_from([(1,'a'), (1,'b'), (2,'b'), (2,'c'), (3,'c'), (4,'a')])
bottom_nodes, top_nodes = bipartite.sets(B)

color = bipartite.color(B)
color_list = []

for c in color.values():
    if c == 0:
        color_list.append('b')
    else:
        color_list.append('r')

# Draw bipartite graph
pos = dict()
color = []
pos.update( (n, (1, i)) for i, n in enumerate(bottom_nodes) ) # put nodes from X at x=1
pos.update( (n, (2, i)) for i, n in enumerate(top_nodes) ) # put nodes from Y at x=2

nx.draw(B, pos=pos, with_labels=True, node_color = color_list)
plt.show()

Is there something I'm missing?

Thanks.


Solution

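  • Your color_list and the node list (B.nodes) are in different orders when you draw the graph. nx.draw matches colours to nodes by position: the i-th entry of node_color goes to the i-th node of B.nodes. Your loop builds color_list from color.values(), so it follows the key order of the dict returned by bipartite.color, not the order of B.nodes: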

    color_list
    ['r', 'b', 'b', 'r', 'b', 'r', 'r']
    
    B.nodes
    NodeView((1, 2, 3, 4, 'a', 'b', 'c'))
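To see which nodes end up with the wrong colour, here is a minimal plain-Python sketch. It uses the node order and colour dict quoted in the question and assumes, as described above, that nx.draw pairs the i-th colour with the i-th node.

```python
# Node order and colour dict as reported in the question.
nodes = ['a', 1, 2, 3, 4, 'c', 'b']                         # order of B.nodes
color = {'a': 1, 1: 0, 2: 0, 'b': 1, 4: 0, 'c': 1, 3: 0}  # bipartite.color(B)

# What the question's loop does: take the values in the dict's own key order.
by_dict_order = list(color.values())

# Position-based pairing: colour i is drawn on node i of B.nodes.
drawn = dict(zip(nodes, by_dict_order))

# Nodes whose drawn colour differs from the colour bipartite.color gave them.
wrong = [n for n in nodes if drawn[n] != color[n]]
print(wrong)  # [3, 'b']
```

Only the positions where the two orders disagree get a swapped colour, which is why the picture looks partly right and partly wrong.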
    

    I built color_list in B.nodes order, using a dictionary that maps each node's "bipartite" attribute (0 or 1) to a colour:

    import matplotlib.pyplot as plt
    import networkx as nx
    from networkx.algorithms import bipartite

    B = nx.Graph()
    B.add_nodes_from([1, 2, 3, 4], bipartite=0)  # Add the node attribute "bipartite"
    B.add_nodes_from(['a', 'b', 'c'], bipartite=1)
    B.add_edges_from([(1, 'a'), (1, 'b'), (2, 'b'), (2, 'c'), (3, 'c'), (4, 'a')])
    bottom_nodes, top_nodes = bipartite.sets(B)

    # Map each node's "bipartite" attribute (0 or 1) to a colour.
    color_dict = {0: 'b', 1: 'r'}

    # B.nodes.data('bipartite') yields (node, attribute) pairs in B.nodes order,
    # so color_list[i] is the colour of the i-th node that nx.draw draws.
    color_list = [color_dict[part] for _, part in B.nodes.data('bipartite')]

    # Draw bipartite graph
    pos = dict()
    pos.update((n, (1, i)) for i, n in enumerate(bottom_nodes))  # put nodes from X at x=1
    pos.update((n, (2, i)) for i, n in enumerate(top_nodes))  # put nodes from Y at x=2

    nx.draw(B, pos=pos, with_labels=True, node_color=color_list)
    plt.show()
    

    Output:

    [Image: output graph]
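If you would rather keep using bipartite.color(B) as in the question, the ordering problem can also be fixed at the lookup. Below is a minimal sketch of two variants, assuming the same graph as above. It does not call nx.draw. The draw calls appear only as comments, and `nodelist` is the standard nx.draw keyword for fixing the node order.

```python
import networkx as nx
from networkx.algorithms import bipartite

B = nx.Graph()
B.add_nodes_from([1, 2, 3, 4], bipartite=0)
B.add_nodes_from(['a', 'b', 'c'], bipartite=1)
B.add_edges_from([(1, 'a'), (1, 'b'), (2, 'b'), (2, 'c'), (3, 'c'), (4, 'a')])

color = bipartite.color(B)   # {node: 0 or 1}; key order is not B.nodes order
color_dict = {0: 'b', 1: 'r'}

# Variant 1: walk B.nodes and look each node's colour up in the dict,
# so color_list[i] belongs to list(B.nodes)[i].
color_list = [color_dict[color[n]] for n in B.nodes]
# nx.draw(B, pos=pos, with_labels=True, node_color=color_list)

# Variant 2: keep the dict's own order and hand that order to nx.draw
# through nodelist, so node_color is matched against nodelist instead.
nodelist = list(color)
color_list_dict_order = [color_dict[c] for c in color.values()]
# nx.draw(B, pos=pos, with_labels=True, nodelist=nodelist,
#         node_color=color_list_dict_order)

for n, c in zip(B.nodes, color_list):
    print(n, c)
```

One caveat: bipartite.color chooses its own 0/1 labels, and they are not guaranteed to match the "bipartite" attribute. Which set ends up red may therefore differ from the attribute-based version above. Each set is still drawn in a single colour.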