Search code examples
Tags: python, networkx, graphviz

Graphviz (Python): neural network layer alignment


I'm trying to draw specific weight connections in an MLP network. What I want to achieve is this:

[Image: desired layout — the two layers centered on each other]

A straightforward Graphviz drawing tilts the graph to follow the drawn edges:

[Image: actual layout — the layers are shifted (tilted) relative to each other]

from graphviz import Graph

# Render the network left-to-right as a PNG into ./graphs: ranks 2 inches apart,
# straight-line edges, and small circles with empty labels for the neurons.
# The graph-level color='white' is inherited by the cluster subgraphs, hiding
# their outlines on a white background.
graph = Graph(
    directory='graphs',
    format='png',
    graph_attr={'ranksep': '2', 'rankdir': 'LR', 'color': 'white', 'splines': 'line'},
    node_attr={'label': '', 'shape': 'circle', 'width': '0.1'},
)


def draw_cluster(name, length):
    """Add a cluster subgraph called ``cluster_<name>`` to the global ``graph``.

    The cluster is labelled *name* and holds *length* nodes whose ids are
    ``<name>_0`` ... ``<name>_<length-1>``. The ``cluster_`` name prefix is
    what makes Graphviz draw the subgraph as a cluster.
    """
    node_ids = [f'{name}_{index}' for index in range(length)]
    with graph.subgraph(name=f'cluster_{name}') as cluster:
        cluster.attr(label=name)
        for node_id in node_ids:
            cluster.node(node_id)


# One cluster per layer: 10 input neurons, 4 output neurons.
for layer_name, layer_size in (('input', 10), ('output', 4)):
    draw_cluster(layer_name, layer_size)

source_active = [0, 1, 2, 3]
sink_active = [2, 3]

# Fully connect the active inputs to the active outputs
# (inputs in the outer position, outputs in the inner one).
active_pairs = [(src, dst) for src in source_active for dst in sink_active]
for src, dst in active_pairs:
    graph.edge(f'input_{src}', f'output_{dst}')

graph.view()

If I add invisible edges between the unconnected neurons, I do force the graph to be centered:
[Image: centered layout obtained with invisible edges]

# Workaround: tie every inactive input (of 10) to every inactive output (of 4)
# with an invisible edge so the layout is pulled toward the middle. The edge
# count is the product of the two inactive counts, hence the cost concern.
active_inputs = set(source_active)
active_outputs = set(sink_active)
for i_input in [idx for idx in range(10) if idx not in active_inputs]:
    for i_output in [idx for idx in range(4) if idx not in active_outputs]:
        graph.edge(f'input_{i_input}', f'output_{i_output}', style='invis')

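But at what cost! My layers can have more than 1000 neurons with only tens of connections, and the number of invisible edges grows as the product of the two layers' inactive-neuron counts, which quickly reaches hundreds of thousands. Maybe networkx can help; I haven't played with it.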

Similar questions that don't help me:

  1. cluster location in graphviz python
  2. GraphViz - alignment of subgraph

Solution

  • I tried your answer:

    from graphviz import Graph
    
    # Render the network left-to-right as a PNG into ./graphs: ranks 2 inches apart,
    # straight-line edges, and small circles with empty labels for the neurons.
    # The graph-level color='white' is inherited by the cluster subgraphs, hiding
    # their outlines on a white background.
    graph = Graph(
        directory='graphs',
        format='png',
        graph_attr={'ranksep': '2', 'rankdir': 'LR', 'color': 'white', 'splines': 'line'},
        node_attr={'label': '', 'shape': 'circle', 'width': '0.1'},
    )
    
    def draw_cluster(name, length):
        """Add a cluster subgraph called ``cluster_<name>`` to the global ``graph``.

        The cluster is labelled *name* and holds *length* nodes whose ids are
        ``<name>_0`` ... ``<name>_<length-1>``. The ``cluster_`` name prefix is
        what makes Graphviz draw the subgraph as a cluster.
        """
        node_ids = [f'{name}_{index}' for index in range(length)]
        with graph.subgraph(name=f'cluster_{name}') as cluster:
            cluster.attr(label=name)
            for node_id in node_ids:
                cluster.node(node_id)
    
    # One cluster per layer: 10 input neurons, 4 output neurons.
    for layer_name, layer_size in (('input', 10), ('output', 4)):
        draw_cluster(layer_name, layer_size)

    source_active = [0, 1, 2, 3]
    sink_active = [2, 3]

    # Fully connect the active inputs to the active outputs
    # (inputs in the outer position, outputs in the inner one).
    active_pairs = [(src, dst) for src in source_active for dst in sink_active]
    for src, dst in active_pairs:
        graph.edge(f'input_{src}', f'output_{dst}')
    
    def central_neurons(layer_size: int):
        """Return the indices of the neuron(s) in the middle of a layer.

        Args:
            layer_size: number of neurons in the layer.

        Returns:
            ``{layer_size // 2}`` for an odd size, the two middle indices
            ``{layer_size // 2 - 1, layer_size // 2}`` for an even size, and an
            empty set when the layer has no neurons (``layer_size <= 0``).
        """
        if layer_size <= 0:
            # An empty (or nonsensical negative) layer has no centre. The old
            # formula returned {0, -1} here, i.e. ids of neurons that don't exist.
            return set()
        middle = layer_size // 2
        if layer_size % 2 == 1:
            return {middle}
        return {middle - 1, middle}
    
    # Invisible edges between the middle neuron(s) of the 10-neuron input layer
    # and of the 4-neuron output layer, intended to line the two layers up
    # around their centres. They pass constraint='true' so they take part in ranking.
    input_centre = central_neurons(layer_size=10)
    output_centre = central_neurons(layer_size=4)
    for source_id in input_centre:
        for sink_id in output_centre:
            graph.edge(f'input_{source_id}', f'output_{sink_id}',
                       constraint='true', style='invis')

    graph.view()
    

    and got the result:
    [Image: resulting MLP network alignment]

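    I think we can add more settings to make the result more centered. For example, we can add rank='same' to align the neurons within each cluster, and set constraint='false' as the default for edges so that the visible connections do not pull the clusters out of central alignment. The invisible edges between the central neurons explicitly keep constraint='true', so they are the only edges that influence how the two layers line up.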

    from graphviz import Graph
    
    # Render the network left-to-right as a PNG into ./graphs: ranks 2 inches apart,
    # straight-line edges, and small circles with empty labels for the neurons.
    # The graph-level color='white' is inherited by the cluster subgraphs, hiding
    # their outlines. Edges default to constraint='false', so they do not affect
    # ranking unless an individual edge overrides it.
    graph = Graph(
        directory='graphs',
        format='png',
        graph_attr={'ranksep': '2', 'rankdir': 'LR', 'color': 'white', 'splines': 'line'},
        node_attr={'label': '', 'shape': 'circle', 'width': '0.1'},
        edge_attr={'constraint': 'false'},
    )
    
    def draw_cluster(name, length):
        """Add a cluster subgraph called ``cluster_<name>`` to the global ``graph``.

        The cluster is labelled *name*, uses rank='same' so all of its nodes
        share one rank (one column under rankdir='LR'), and holds *length*
        nodes whose ids are ``<name>_0`` ... ``<name>_<length-1>``.
        """
        node_ids = [f'{name}_{index}' for index in range(length)]
        with graph.subgraph(name=f'cluster_{name}') as cluster:
            cluster.attr(label=name, rank='same')
            for node_id in node_ids:
                cluster.node(node_id)
    
    # One cluster per layer: 10 input neurons, 4 output neurons.
    for layer_name, layer_size in (('input', 10), ('output', 4)):
        draw_cluster(layer_name, layer_size)

    source_active = [0, 1, 2, 3]
    sink_active = [2, 3]

    # Fully connect the active inputs to the active outputs. These visible edges
    # inherit the graph-wide constraint='false', so they do not shift the layers.
    active_pairs = [(src, dst) for src in source_active for dst in sink_active]
    for src, dst in active_pairs:
        graph.edge(f'input_{src}', f'output_{dst}')
    
    def central_neurons(layer_size: int):
        """Return the indices of the neuron(s) in the middle of a layer.

        Args:
            layer_size: number of neurons in the layer.

        Returns:
            ``{layer_size // 2}`` for an odd size, the two middle indices
            ``{layer_size // 2 - 1, layer_size // 2}`` for an even size, and an
            empty set when the layer has no neurons (``layer_size <= 0``).
        """
        if layer_size <= 0:
            # An empty (or nonsensical negative) layer has no centre. The old
            # formula returned {0, -1} here, i.e. ids of neurons that don't exist.
            return set()
        middle = layer_size // 2
        if layer_size % 2 == 1:
            return {middle}
        return {middle - 1, middle}
    
    # Invisible edges between the middle neuron(s) of the 10-neuron input layer
    # and of the 4-neuron output layer. They override the graph's default
    # constraint='false' with 'true', so they alone decide how the layers line up.
    input_centre = central_neurons(layer_size=10)
    output_centre = central_neurons(layer_size=4)
    for source_id in input_centre:
        for sink_id in output_centre:
            graph.edge(f'input_{source_id}', f'output_{sink_id}',
                       constraint='true', style='invis')
    graph.view()
    

    Result:
    [Image: resulting MLP network alignment with rank='same' and constraint='false']