Tags: python, matplotlib, plot, graphviz

Automatically assign color to nodes in Graphviz


I'm using Python and Graphviz to draw a cluster graph consisting of nodes. I want to give each node a different color depending on an attribute, e.g. its x-coordinate.

Here's how I produce the graph:

import functools

import graphviz as gv

def add_nodes(graph, nodes):
    # Each node is either a plain name or a (name, attributes) tuple
    for n in nodes:
        if isinstance(n, tuple):
            graph.node(n[0], **n[1])
        else:
            graph.node(n)
    return graph

A = [[517, 1, [409], 10, 6], 
     [534, 1, [584], 10, 12], 
     [614, 1, [247], 11, 5], 
     [679, 1, [228], 13, 7], 
     [778, 1, [13], 14, 14]]

nodesgv = []

for node in A:
    # ??? is the part I don't know: a color that depends on node[0]
    nodesgv.append((str(node[0]), {'label': str(node[0]), 'color': ???, 'style': 'filled'}))

graph = functools.partial(gv.Graph, format='svg', engine='neato')
add_nodes(graph(), nodesgv).render('img/test')

Now I want to give each node a color according to the ordering of its first value. More specifically, what I want is:

  • a red node (517)
  • a yellow node (534)
  • a green node (614)
  • a blue node (679)
  • and a purple node (778)

(Image: 5 differently coloured nodes with numbers in them: red 517, yellow 534, green 614, blue 679, purple 778.)

I know how to assign a fixed color to a node, but what I'm looking for is something similar to the c=x argument in matplotlib, which colors each point by its value:

plt.scatter(x, y, c=x, s=node_sizes)

The problem is that I can't know the number of nodes (clusters) beforehand. For example, if I have 7 nodes, I still want 7 colors that start with red and end with purple.
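Roughly, c=x works in two steps: the values are first rescaled linearly to [0, 1] using their minimum and maximum (matplotlib's Normalize), and the colormap then turns each number in [0, 1] into a color. The sketch below imitates those two steps in plain Python, so it works for any number of nodes. The helper names normalize and interpolate_color and the five color stops are hypothetical, and the simple linear blend between stops is a simplification, not matplotlib's actual implementation.

```python
# Hypothetical sketch: imitate "Normalize + colormap" without matplotlib.
# The color stops (red, yellow, green, blue, purple) are an assumption
# chosen to match the ordering asked for in the question.
STOPS = [
    (255, 0, 0),    # red
    (255, 255, 0),  # yellow
    (0, 128, 0),    # green
    (0, 0, 255),    # blue
    (128, 0, 128),  # purple
]


def normalize(value, vmin, vmax):
    """Linearly rescale value from [vmin, vmax] to [0, 1]."""
    if vmax == vmin:
        return 0.0  # all values equal: use the first color
    return (value - vmin) / (vmax - vmin)


def interpolate_color(t, stops=STOPS):
    """Map t in [0, 1] to a '#rrggbb' string by blending neighbouring stops."""
    t = min(max(t, 0.0), 1.0)
    pos = t * (len(stops) - 1)         # position along the list of stops
    i = min(int(pos), len(stops) - 2)  # index of the left-hand stop
    frac = pos - i                     # fraction of the way to the next stop
    left, right = stops[i], stops[i + 1]
    rgb = [round(a + (b - a) * frac) for a, b in zip(left, right)]
    return "#{:02x}{:02x}{:02x}".format(*rgb)


values = [517, 534, 614, 679, 778]
vmin, vmax = min(values), max(values)
# Graphviz accepts "#rrggbb" strings for the color attribute
colors = {v: interpolate_color(normalize(v, vmin, vmax)) for v in values}
print(colors)
```

Because the mapping depends only on the minimum and maximum, the smallest value is always red and the largest always purple, however many nodes there are. Values that are close together (like 517 and 534) get similar colors, just as with c=x.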

So is there any attribute in Graphviz that can do this?

Or can anyone explain how colormaps work in matplotlib?

Sorry for the lack of clarity. T^T


Solution

  • Oh, I figured out a way to get what I want. Recording it here in case someone else has the same problem: normalize a colormap to the range of the values and give each node the color at the corresponding position.

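    import colorsys
    import functools

    import graphviz as gv
    import matplotlib as mpl
    import matplotlib.cm as cm

    def add_nodes(graph, nodes):
        for n in nodes:
            if isinstance(n, tuple):
                graph.node(n[0], **n[1])
            else:
                graph.node(n)
        return graph
    
    A = [[517, 1, [409], 10, 6], 
         [534, 1, [584], 10, 12], 
         [614, 1, [247], 11, 5], 
         [679, 1, [228], 13, 7], 
         [778, 1, [13], 14, 14]]
    
    nodesgv = []
    Arange = [a[0] for a in A]
    # Map the smallest first value to 0.0 and the largest to 1.0,
    # however many nodes there are
    norm = mpl.colors.Normalize(vmin=min(Arange), vmax=max(Arange))
    cmap = cm.jet  # jet runs blue -> red; cm.rainbow_r runs red -> purple
    m = cm.ScalarMappable(norm=norm, cmap=cmap)  # build once, outside the loop
    
    for i in A:
        x = i[0]
        r, g, b, _ = m.to_rgba(x)
        # Graphviz accepts an "H, S, V" color string with components in [0, 1]
        h, s, v = colorsys.rgb_to_hsv(r, g, b)
        # label with the first value, as in the picture in the question
        nodesgv.append((str(x), {'label': str(x), 'color': "%f, %f, %f" % (h, s, v), 'style': 'filled'}))
    
    graph = functools.partial(gv.Graph, format='svg', engine='neato')
    add_nodes(graph(), nodesgv).render('img/test')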
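Note that Normalize spaces the colors by value, so 517 and 534 both land near the same end of the colormap. If the colors should instead follow the ordering of the nodes (first node red, last node purple, whatever the actual values), one option is to color by rank and hand Graphviz an HSV triple directly, since it accepts "H,S,V" strings with each component in [0, 1]. A minimal sketch, assuming a hue range from 0.0 (red) to 0.8 (roughly purple); the helper rank_colors is hypothetical and swaps the colormap for evenly spaced hues:

```python
def rank_colors(values, max_hue=0.8):
    """Return {value: 'H,S,V'} with hues evenly spaced by rank.

    Hue 0.0 is red and max_hue=0.8 is roughly purple (an assumption);
    saturation and value are fixed at 1.0.
    """
    ordered = sorted(set(values))
    n = len(ordered)
    colors = {}
    for rank, v in enumerate(ordered):
        # Spread the ranks evenly over [0, max_hue]; a single node gets red
        hue = rank / (n - 1) * max_hue if n > 1 else 0.0
        colors[v] = "%.3f,1.000,1.000" % hue  # Graphviz "H,S,V" color string
    return colors


values = [517, 534, 614, 679, 778]
colors = rank_colors(values)
for v in values:
    print(v, colors[v])
```

With 7 nodes the hues are simply spaced 0.8/6 apart, so the first node is still red and the last still purple. In the loop above, the color could then be taken as colors[i[0]] instead of going through ScalarMappable.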