Search code examples
python 3d networkx

How to convert 2D networkx graph to interactive 3D in python?


I have already built a 2D network graph using networkx in Python.

Code used to build:

import matplotlib as mpl
import networkx as nx  # `nx` is used below; without this import the script raises NameError
import pandas as pd

# Edge list; the columns used below are 'var1', 'var2' and 'Category'.
links_data = pd.read_csv("https://raw.githubusercontent.com/johnsnow09/network_graph/refs/heads/main/links_filtered.csv")


# Build the graph: every row contributes one edge between var1 and var2.
G = nx.from_pandas_edgelist(links_data, 'var1', 'var2')

cmap = mpl.colormaps['Set3'].colors # this has 12 colors for 11 categories

# Pair each distinct category (in order of first appearance) with a palette colour.
cat_colors = dict(zip(links_data['Category'].unique(), cmap))

# Per-node colours aligned with G.nodes. A node's category is taken from the
# first row in which it appears as var1; nodes that never appear in var1 end
# up as NaN after the reindex.
colors = (links_data
 .drop_duplicates('var1').set_index('var1')['Category']
 .map(cat_colors)
 .reindex(G.nodes)
)

nx.draw(G, with_labels=True, node_color=colors, node_size=200,
        edge_color='black', linewidths=.5, font_size=2.5)

It gives the image below as output: [image: enter image description here]

How can I convert it into a 3D graph so that I can get a better view of the network relations in the graph?

Appreciate any help!


Solution

  • pyvis is a straightforward solution you can try.

    from pyvis.network import Network
    import networkx as nx
    import pandas as pd
    import matplotlib as mpl

    # Colour for nodes that have no category colour of their own.
    FALLBACK_COLOR = "#cccccc"

    # Edge list; the columns used below are 'var1', 'var2' and 'Category'.
    links_data = pd.read_csv("https://raw.githubusercontent.com/johnsnow09/network_graph/refs/heads/main/links_filtered.csv")

    # Build the graph: every row contributes one edge between var1 and var2.
    G = nx.from_pandas_edgelist(links_data, 'var1', 'var2')

    net = Network(notebook=True, height="750px", width="100%")

    # Pair each distinct category (in order of first appearance) with a Set3 colour.
    cmap = mpl.colormaps['Set3'].colors
    cat_colors = dict(zip(links_data['Category'].unique(), cmap))

    # Category of each node, taken from the first row where it appears as var1.
    node_category = (links_data
        .drop_duplicates('var1').set_index('var1')['Category']
        .to_dict()
    )

    for node in G.nodes:
        # A node that only ever appears in var2 has no category, and a category
        # may lack a palette entry; rgb2hex() cannot convert a missing value,
        # so those nodes get the fallback colour instead of raising.
        rgb = cat_colors.get(node_category.get(node))
        # rgb2hex() already returns a '#rrggbb' string, which pyvis accepts as-is.
        color = mpl.colors.rgb2hex(rgb) if rgb is not None else FALLBACK_COLOR
        net.add_node(node, color=color)

    for edge in G.edges():
        net.add_edge(*edge)

    # Writes the interactive page to interactive_graph.html.
    net.show("interactive_graph.html")
    

    This produces interactive_graph.html, where you can toggle between nodes interactively. Go through the pyvis documentation; it shows more ways to customize the graph.

    Edit 1:

    Plotly:

    You can create Plotly 3D scatter traces for the nodes and the edges and combine the traces into a figure.

    import plotly.graph_objects as go
    import networkx as nx
    import pandas as pd
    import matplotlib as mpl

    # Colour for nodes that have no category colour of their own.
    FALLBACK_COLOR = "#cccccc"

    # Edge list; the columns used below are 'var1', 'var2' and 'Category'.
    links_data = pd.read_csv("https://raw.githubusercontent.com/johnsnow09/network_graph/refs/heads/main/links_filtered.csv")

    # Build the graph: every row contributes one edge between var1 and var2.
    G = nx.from_pandas_edgelist(links_data, 'var1', 'var2')

    # dim=3 yields an (x, y, z) position per node; the fixed seed makes the
    # layout reproducible between runs.
    pos_3d = nx.spring_layout(G, dim=3, seed=42)

    # Node coordinates, in G.nodes order.
    x_nodes = [pos_3d[node][0] for node in G.nodes]
    y_nodes = [pos_3d[node][1] for node in G.nodes]
    z_nodes = [pos_3d[node][2] for node in G.nodes]

    # Extract edges. All edges share a single trace: each edge contributes its
    # two endpoints followed by None, which breaks the line so consecutive
    # edges are not joined to each other.
    edge_x = []
    edge_y = []
    edge_z = []

    for edge in G.edges():
        x0, y0, z0 = pos_3d[edge[0]]
        x1, y1, z1 = pos_3d[edge[1]]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_z += [z0, z1, None]

    # Pair each distinct category (in order of first appearance) with a Set3 colour.
    cmap = mpl.colormaps['Set3'].colors
    cat_colors = dict(zip(links_data['Category'].unique(), cmap))

    # Category of each node, taken from the first row where it appears as var1.
    node_category = (links_data
        .drop_duplicates('var1').set_index('var1')['Category']
        .to_dict()
    )

    # Hex colour per node, in G.nodes order. A node that only ever appears in
    # var2 has no category, and a category may lack a palette entry; rgb2hex()
    # cannot convert a missing value, so those nodes get the fallback colour
    # instead of raising.
    node_colors = []
    for node in G.nodes:
        rgb = cat_colors.get(node_category.get(node))
        node_colors.append(
            mpl.colors.rgb2hex(rgb) if rgb is not None else FALLBACK_COLOR
        )

    # Create Plotly 3D scatter plot for nodes
    node_trace = go.Scatter3d(
        x=x_nodes,
        y=y_nodes,
        z=z_nodes,
        mode='markers',
        marker=dict(size=8, color=node_colors, line_width=0.5),
        text=list(G.nodes),  # Add node names for hover
        hoverinfo='text'
    )

    # Single line trace holding every edge segment; hover is disabled on edges.
    edge_trace = go.Scatter3d(
        x=edge_x,
        y=edge_y,
        z=edge_z,
        mode='lines',
        line=dict(color='black', width=0.5),
        hoverinfo='none'
    )

    # Edges are listed first, nodes second.
    fig = go.Figure(data=[edge_trace, node_trace])

    # Remove the legend, the margins and the axis background panes.
    fig.update_layout(
        showlegend=False,
        margin=dict(l=0, r=0, t=0, b=0),
        scene=dict(
            xaxis=dict(showbackground=False),
            yaxis=dict(showbackground=False),
            zaxis=dict(showbackground=False)
        )
    )

    fig.show()
    

    <iframe src="https://bhargav-ravinuthala.github.io/ploty/" width="400" height="400"></iframe>