How to properly draw Residual Neural Network blocks with Graphviz?


I am trying to create a graph of a ConvNet architecture with residual connections, using the following graph definition:

digraph Model {
    node [shape=box];

    input [label="Input"];
    n1 [label="Conv2D(256, BN, ReLU)"];
    n2 [label="Conv2D(256, BN, ReLU)"];
    n3 [label="Conv2D(128, BN, ReLU)"];
    n4 [label="Conv2D(128, BN, ReLU)"];
    n5 [label="GlobalPool2d"];
    n6 [label="Flatten"];
    top [label="Dense(1, Sigmoid)"];

    add1 [label="Add"];
    add2 [label="Add"];

    input -> n1;
    n1 -> n2;
    n1 -> add1;
    n2 -> add1;
    add1 -> n3;

    n3 -> n4;
    n3 -> add2;
    n4 -> add2;
    add2 -> n5;

    n5 -> n6 -> top;
}

The generated plot looks like this:

[image: generated graph]

The problem is that the residual connections push the convolution layers to the left. Is it possible to align the boxes along a vertical axis, so that all the layers sit on the same vertical line and the residual connections route around them? I've tried manipulating rank and rankdir, without any luck.

Could you help me with this, or point me to the relevant part of the documentation that explains how to do it properly?


Solution

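  • The weight attribute will do what you want. In dot, the heavier an edge's weight, the shorter, straighter and more vertical dot tries to draw it (the default is 1). Setting weight=0 on the two skip connections stops them from pulling the main chain sideways. Change two lines to: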

       n1 -> add1 [weight=0];
       n3 -> add2 [weight=0];
    

    You'll get this: [image: rendered graph with the two edges changed]
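For reference, here is the question's complete graph with only those two edges changed (plus consistent Conv2D labels). It is a sketch assembled from the question and the answer, not a separately verified rendering; the comments describe the documented meaning of weight in dot.

```dot
digraph Model {
    node [shape=box];

    input [label="Input"];
    n1 [label="Conv2D(256, BN, ReLU)"];
    n2 [label="Conv2D(256, BN, ReLU)"];
    n3 [label="Conv2D(128, BN, ReLU)"];
    n4 [label="Conv2D(128, BN, ReLU)"];
    n5 [label="GlobalPool2d"];
    n6 [label="Flatten"];
    top [label="Dense(1, Sigmoid)"];

    add1 [label="Add"];
    add2 [label="Add"];

    // Main path edges keep the default weight (1), so dot
    // tries to keep them short, straight and vertical.
    input -> n1;
    n1 -> n2;
    // Residual (skip) edge: weight=0 removes the pull to keep it
    // straight, so it bends around the column instead.
    n1 -> add1 [weight=0];
    n2 -> add1;
    add1 -> n3;

    n3 -> n4;
    n3 -> add2 [weight=0];  // second residual (skip) edge
    n4 -> add2;
    add2 -> n5;

    n5 -> n6 -> top;
}
```

The edge statements are kept in the question's original order, because dot takes input order into account when it first places nodes within a rank. If the main column still drifts in a larger model, dot's group attribute (giving all main-path nodes the same group value) is another documented way to ask dot to keep the edges between those nodes straight. I haven't tested it on this graph.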