Tags: networkx, draw, k-means, directed-graph

Draw Networkx Directed Graph using clustering labels as color scheme


I need help drawing a networkx directed graph. I create the graph from a dataframe that looks like this:

source    target    weight
ip_1      ip_2      3
ip_1      ip_3      6
ip_4      ip_3      7
.
.
.

I then clustered this graph with k-means, choosing k with the elbow method, after converting the nodes into embeddings using Node2Vec:

https://github.com/eliorc/node2vec

At the end, I have this resulting dataframe:

source    target    weight    source_kmeans_label    target_kmeans_label    elbow_optimal_k
ip_1      ip_2      3         0                      1                      12
ip_1      ip_3      6         2                      0                      12
ip_4      ip_3      7         0                      3                      12
.
.
.

I want to visualize (draw) this graph (source, target, weight) using a different color for each cluster. The elbow method gives k = 12 in the example above, so I need 12 different colors. I would really appreciate any help with this, thanks.


Solution

  • You can use a seaborn palette to generate 12 different RGB color values and then create a column called color in your dataframe based on the weight values:

    import seaborn as sns
    import networkx as nx
    from pyvis.network import Network
    
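    # n_colors is your elbow value; as_hex() yields color strings such as '#f77189',
    # which pyvis can render (raw float RGB tuples are not valid color values there)
    palette = sns.color_palette("husl", n_colors=12).as_hex()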
    

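    Assuming your dataframe is called df and the weights are integers, you can add the new color column based on the weight column as follows (the modulo keeps the index inside the palette when a weight is larger than the number of colors):

    df['color'] = df.apply(lambda row: palette[(row['weight'] - 1) % len(palette)], axis=1)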
    
    

    Now that each edge has a color, build the graph from the dataframe and then visualize it with pyvis:

    G = nx.from_pandas_edgelist(df, 'source', 'target', edge_attr='color', create_using=nx.DiGraph())
    N = Network(height='100%', width='100%', bgcolor='white', font_color='black', directed=True)
    
    for n in G.nodes:
        N.add_node(n)
    for u, v, attrs in G.edges.data():  # each item is (source, target, attribute dict)
        N.add_edge(u, v, color=attrs['color'])
    
    N.write_html('path/to/your_graph.html')
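The snippet above colors edges by weight. Since the question asks for colors that follow the cluster labels, here is a minimal sketch of one way to color the nodes by their k-means label instead. It is not part of the original answer, and it rests on a few assumptions: every node carries one consistent label across all rows, labels are integers from 0 to k - 1, and the helper names (make_palette, node_cluster_labels, draw_colored_by_cluster) are hypothetical. The palette is built with the standard-library colorsys module as a stand-in for the seaborn palette, so the helpers run without extra dependencies.

```python
import colorsys


def make_palette(k):
    """Return k evenly spaced hex colors (a stand-in for a seaborn palette)."""
    colors = []
    for i in range(k):
        # Spread hues evenly around the color wheel at fixed lightness/saturation.
        r, g, b = colorsys.hls_to_rgb(i / k, 0.55, 0.65)
        colors.append('#{:02x}{:02x}{:02x}'.format(
            round(r * 255), round(g * 255), round(b * 255)))
    return colors


def node_cluster_labels(rows):
    """Map each node to its cluster label.

    rows: iterable of (source, target, weight, source_label, target_label).
    Raises ValueError if a node appears with two different labels.
    """
    labels = {}
    for source, target, _weight, source_label, target_label in rows:
        for node, label in ((source, source_label), (target, target_label)):
            if labels.setdefault(node, label) != label:
                raise ValueError(f'node {node!r} has conflicting labels')
    return labels


def draw_colored_by_cluster(rows, k, out_path):
    """Write a pyvis HTML file with nodes colored by cluster label."""
    # Imported here so the two helpers above can be used without pyvis.
    from pyvis.network import Network

    rows = list(rows)
    palette = make_palette(k)
    labels = node_cluster_labels(rows)

    net = Network(height='750px', width='100%', directed=True)
    for node, label in labels.items():
        net.add_node(node, label=str(node), color=palette[label],
                     title=f'cluster {label}')
    for source, target, weight, _, _ in rows:
        # 'value' lets vis.js scale the edge width by weight.
        net.add_edge(source, target, value=float(weight))
    net.write_html(out_path)
```

With the dataframe from the question, the rows could be produced with df[['source', 'target', 'weight', 'source_kmeans_label', 'target_kmeans_label']].itertuples(index=False, name=None), and k would be the elbow_optimal_k value (12 in the example), for instance draw_colored_by_cluster(rows, 12, 'path/to/your_graph.html').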