Search code examples
python numpy matplotlib networkx

Arrows on edges in a graph using Networkx


I have a graph with nodes and edges. The code colors the edges listed in the array Edges according to the values in the array Weights, as shown in the current output. Is it possible to put arrows on the edges listed in Edges, as displayed in the expected output? I want arrows only on the specific edges given in Edges, not on all edges.

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable

# Size parameter: the script below creates 2 * N * (N + 1) nodes.
N = 1


def pos():
    """Yield one (x, y) coordinate tuple per node, 2 * N * (N + 1) in total.

    The sequence starts at (1, N + 2). After each yield, x advances by 2
    modulo N + 3, and y decreases by (x + 2) // (N + 3), i.e. by the number
    of times x wrapped around.
    """
    width = N + 3
    x, y = 1, width - 1
    for _ in range(2 * N * (N + 1)):
        yield (x, y)
        # divmod gives both the wrap count (rows to drop) and the new x.
        wraps, x = divmod(x + 2, width)
        y -= wraps

# Build the undirected graph: nodes are numbered 1 .. 2 * N * (N + 1) and each
# one stores the next coordinate produced by pos() in its 'pos' attribute.
G = nx.Graph()
it_pos = pos()
period = 2 * N + 1
for u in range(2 * N * (N + 1)):
    node = u + 1
    phase = u % period
    G.add_node(node, pos=next(it_pos))

    if phase == N:
        # Single link back to the node N indices earlier.
        G.add_edge(node, u - N + 1)
        continue

    if phase < N:
        # Candidates may not exist yet; link only to nodes already in G.
        neighbours = [
            v for v in (u - 2 * N - 1, u - N - 1, u - N) if G.has_node(v + 1)
        ]
    elif phase < 2 * N:
        neighbours = [u - 1, u - N - 1, u - N]
    else:
        neighbours = [u - 1, u - N - 1]

    for v in neighbours:
        G.add_edge(node, v + 1)


# Node coordinates that were attached to each node when the graph was built.
# NOTE: this rebinds the module-level name `pos`, which up to this point
# referred to the coordinate generator function defined above.
pos = nx.get_node_attributes(G, 'pos')

# Base layer: every node and edge of G in the default style.
nx.draw(G, pos, with_labels=True, font_weight='bold')

# Edges to highlight and the weight associated with each of them (row i of
# Weights belongs to row i of Edges).
Edges = np.array([[1, 2], [1, 3], [1, 4]])
Weights = np.array([[1.7], [2.9], [8.6]])
edge_list = [tuple(e) for e in Edges.tolist()]

# Scale the weights into [0, 1] relative to the largest weight; the maximum is
# computed once instead of once per element.
flat_weights = Weights.flatten()
max_weight = flat_weights.max()
weights_normalized = flat_weights / max_weight
edge_weight_map = dict(zip(edge_list, weights_normalized))

# plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
my_cmap = plt.get_cmap('Oranges')
colors = my_cmap([edge_weight_map.get(e, 0) for e in edge_list])

# Overlay the selected edges, coloured by their normalised weight.
nx.draw_networkx_edges(G, pos, edgelist=edge_list, edge_color=colors, width=5)

# The ScalarMappable is not attached to any Axes, so the Axes the colorbar
# should take space from has to be passed explicitly.
sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(0, max_weight))
plt.colorbar(sm, ax=plt.gca())

The current output is

enter image description here

The expected output is

enter image description here


Solution

  • You are almost there.

    First, you will need to create a directed graph instead of an undirected graph:

    # Directed graph, so that edges can be drawn with arrow heads.
    G = nx.DiGraph()
    

    Second, DiGraph objects are plotted with arrow heads by default, so you need to specify arrows=False in the call to nx.draw(...) to draw the full graph without arrow heads.

    # arrows=False: draw the base graph without arrow heads.
    nx.draw(G, nx.get_node_attributes(G, 'pos'), with_labels=True, font_weight='bold', arrows=False)
    

    If you now plot your selected edges separately with nx.draw_networkx_edges(...), they are drawn with arrow heads by default.

    enter image description here