
NetworkX - How to reproduce the drawing shape


import matplotlib.pyplot as plt
import networkx as nx

MDG = nx.MultiDiGraph()

# --------------------------------------------------------------------------------
# Forward edges
# --------------------------------------------------------------------------------
forward_edges = [
    ("sushi/0", "goes/1"), 
    ("sushi/0", "with/2"), 
    ("sushi/0", "wasabi/3"), 
    ("goes/1", "with/2"),
    ("goes/1", "wasabi/3"),
    ("with/2", "wasabi/3"),
]
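# Keyword arguments here are stored as attributes on every edge. spring_layout
# below uses the numeric 'length' attribute; weight='length' merely creates an
# unused 'weight' attribute holding the string 'length'.
MDG.add_edges_from(
    forward_edges,
    edge_color='b',
    weight='length',
    length=100,
)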

# The layout must be computed after all nodes/edges have been added;
# otherwise drawing fails with a "Node X has no position" error.
pos = nx.spring_layout(MDG, weight='length')
fig, ax = plt.subplots(figsize=(10, 5))

# --------------------------------------------------------------------------------
# Draw nodes & labels
# --------------------------------------------------------------------------------
nx.draw_networkx_nodes(
    MDG, 
    pos, 
    ax=ax,
    node_size=500,
    node_color="cyan",
)
nx.draw_networkx_labels(
    MDG, 
    pos, 
    ax=ax,
    font_weight='bold',
    font_size=12,
)

# --------------------------------------------------------------------------------
# Draw forward edges
# --------------------------------------------------------------------------------
nx.draw_networkx_edges(
    MDG, 
    pos, 
    ax=ax, 
    edgelist=forward_edges, 
    edge_color="b",
    arrowsize=20,
    arrowstyle="->",
#    connectionstyle='arc3, rad=0.25',
)
nx.draw_networkx_edge_labels(
    MDG, 
    pos,
    label_pos=0.2,
    edge_labels={
        ("sushi/0", "goes/1"): 0.3, 
        ("sushi/0", "with/2"): 0.1, 
        ("sushi/0", "wasabi/3"): 0.6, 
        ("goes/1", "with/2"): 0.75,
        ("goes/1", "wasabi/3"): 0.75,
        ("with/2", "wasabi/3"): 0.75,
    },
    font_color='b'
)
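# On a MultiDiGraph each edge is addressed as (u, v, key). The edges above were
# added without explicit keys, so each has key 0; tuples with any other key are
# silently ignored by set_edge_attributes.
nx.set_edge_attributes(
    G=MDG,
    values={
        ("sushi/0", "goes/1", 0): {"label": 0.1},
        ("sushi/0", "with/2", 0): {"label": 0.2},
        ("sushi/0", "wasabi/3", 0): {"label": 0.6},
        ("goes/1", "with/2", 0): {"label": 0.1},
        ("goes/1", "wasabi/3", 0): {"label": 0.75},
        ("with/2", "wasabi/3", 0): {"label": 0.75},
    }
)

plt.show()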

[Image: the drawing produced by the first run]

However, the next time I run the same code, the shape is different.

[Image: the drawing produced by a later run, with a different layout]

Question

Is there a way to consistently reproduce the first shape? Apparently the seed parameter is the one to use, but how do I find the value to pass as seed so that the first figure is reproduced every time?

I suppose that after the first figure is generated, I need to find out which seed value was actually used and then reuse it, but I'm not sure how.
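For reference, `nx.spring_layout` accepts a `seed` argument. When `seed` is left as `None`, the starting positions come from NumPy's global random state, which is itself seeded unpredictably when NumPy starts, so there is no seed value to read back from an earlier unseeded run. The workable route is the reverse: choose the seed yourself before drawing. The sketch below assumes a reasonably recent NetworkX release and uses the question's edges with a hypothetical `SEED = 42`; identical positions should only be expected for the same graph, the same node insertion order, and the same library versions.

```python
import networkx as nx
import numpy as np

# The question's edges; 'length' is the attribute spring_layout uses as weight.
MDG = nx.MultiDiGraph()
MDG.add_edges_from(
    [
        ("sushi/0", "goes/1"),
        ("sushi/0", "with/2"),
        ("sushi/0", "wasabi/3"),
        ("goes/1", "with/2"),
        ("goes/1", "wasabi/3"),
        ("with/2", "wasabi/3"),
    ],
    length=100,
)

SEED = 42  # hypothetical choice; any fixed int works

# Same graph + same seed -> same starting positions -> same final layout.
pos_first = nx.spring_layout(MDG, weight="length", seed=SEED)
pos_again = nx.spring_layout(MDG, weight="length", seed=SEED)
print(all(np.array_equal(pos_first[n], pos_again[n]) for n in MDG))  # True

# To find a layout worth keeping, generate a few candidates, draw each one,
# and keep the seed whose picture you like.
candidates = {s: nx.spring_layout(MDG, weight="length", seed=s) for s in range(5)}
```

Once a seed gives the shape you want, passing it on the question's `pos = nx.spring_layout(MDG, weight='length')` line (e.g. `seed=SEED`) should leave the rest of the drawing code unchanged.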


Solution

  • The value returned from nx.spring_layout is a dict that you can save and reuse (or even modify) once you have a layout you want to keep. If you print it, you'll see that it maps each node to a NumPy array holding that node's (x, y) position.

    pos = {'sushi/0': array([0.20632381, 0.0243153]),
           'goes/1': array([0.59708043, 0.2527625]),
           'with/2': array([0.19659576, 0.52970947]),
           'wasabi/3': array([-1., -0.80678727])}
    

    If you want exactly the same node positions on every run of your program, you could pickle that dict and load it instead of calling nx.spring_layout. (For a small example like this, you could also hard-code it.)
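A minimal sketch of the pickle approach described above. The file location (`sushi_layout.pkl` in the system temp directory) and the helper `load_or_compute_layout` are hypothetical names chosen for illustration: the first call computes the layout and saves it, and every later call, including in later runs of the script, loads the saved dict instead of calling `nx.spring_layout` again.

```python
import pickle
import tempfile
from pathlib import Path

import networkx as nx
import numpy as np

MDG = nx.MultiDiGraph()
MDG.add_edges_from(
    [("sushi/0", "goes/1"), ("sushi/0", "with/2"), ("sushi/0", "wasabi/3"),
     ("goes/1", "with/2"), ("goes/1", "wasabi/3"), ("with/2", "wasabi/3")],
    length=100,
)

# Hypothetical location for the saved layout.
LAYOUT_FILE = Path(tempfile.gettempdir()) / "sushi_layout.pkl"


def load_or_compute_layout(graph, path):
    """Return a saved layout if one exists; otherwise compute, save, and return it."""
    if path.exists():
        with path.open("rb") as f:
            return pickle.load(f)
    pos = nx.spring_layout(graph, weight="length")
    with path.open("wb") as f:
        pickle.dump(pos, f)
    return pos


LAYOUT_FILE.unlink(missing_ok=True)  # start fresh for this demo only
pos = load_or_compute_layout(MDG, LAYOUT_FILE)           # computed and saved
pos_reloaded = load_or_compute_layout(MDG, LAYOUT_FILE)  # loaded from disk
print(all(np.array_equal(pos[n], pos_reloaded[n]) for n in MDG))  # True
```

In a real script you would drop the `unlink` line so the saved file survives between runs. The saved dict is keyed by node name, so if you later add nodes, the loaded layout will not contain them and drawing will fail with a "Node X has no position" error until the file is regenerated. For hard-coding, `{n: xy.tolist() for n, xy in pos.items()}` prints plain floats you can paste into the script; the drawing functions only index into `pos`, so lists should work as well as NumPy arrays.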