Tags: tensorboard, pytorch-lightning

Organize tensorboard graph with pytorch lightning


I've added the default TensorBoard logger (from pytorch_lightning.loggers import TensorBoardLogger) to my PyTorch Lightning Trainer with log_graph=True.
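
For context, this is roughly that setup. A minimal sketch, where save_dir, name, and model are placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    # save_dir and name are placeholders; with log_graph=True the logger writes the
    # model graph, provided the LightningModule defines example_input_array.
    logger = TensorBoardLogger(save_dir="tb_logs", name="my_model", log_graph=True)
    trainer = Trainer(logger=logger)
    trainer.fit(model)  # model: your LightningModule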

When I train my model, the first view of my graph shows three blocks:

inputs => MyNetworkClassName => Outputs

So far so good.

But then, when I expand MyNetworkClassName it gives me absolutely everything that's going on in my net. That's a lot of arrows going everywhere. I would like to organize this graph into simpler blocks with expandable subgraphs. So in my case, where my network has a typical encoder - enhancer - decoder structure, I would like something more like this:

  • First graph:

    1. inputs => MyNetworkClassName => Outputs
  • zooming in on MyNetworkClassName:

    1. encoder => enhancer => decoder
  • zooming in on encoder:

    1. encoder_layer1 => encoder_layer2 => ...
  • zooming in on encoder_layer1:

    1. conv2d => batchnorm

What are my options here? Should I put everything into separate classes? Are there any commands that allow me to group certain actions together?


Solution

  • Refactoring code into classes also affects the TensorBoard graph (whereas refactoring into methods does not). A typical example of a class that will show up as an expandable block:

    import torch.nn as nn

    class EncoderLayer(nn.Module):
        """Encoder layer: complex convolution, batch norm, and activation."""

        def __init__(self, activation_function, kernel_num, kernel_size, idx):
            super().__init__()
            self.layer = nn.Sequential(
                ComplexConv2d(                  # custom complex-valued conv module
                    kernel_num[idx],
                    kernel_num[idx + 1],
                    kernel_size=(kernel_size, 2),
                ),
                nn.BatchNorm2d(kernel_num[idx + 1]),
                activation_function,
            )

        def forward(self, x):
            return self.layer(x)
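
Applying the same idea one level up gives exactly the hierarchy asked for: each nn.Module subclass becomes its own expandable block in the graph. The following is a hedged sketch rather than the original code; Encoder, MyNetwork, the enhancer/decoder arguments, and the example_input_array shape are illustrative assumptions (example_input_array is what log_graph=True traces):

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl

    class Encoder(nn.Module):
        """Expands into encoder_layer1 => encoder_layer2 => ... in TensorBoard."""

        def __init__(self, activation_function, kernel_num, kernel_size):
            super().__init__()
            self.layers = nn.Sequential(*[
                EncoderLayer(activation_function, kernel_num, kernel_size, idx)
                for idx in range(len(kernel_num) - 1)
            ])

        def forward(self, x):
            return self.layers(x)

    class MyNetwork(pl.LightningModule):
        """Top-level block: expands into encoder => enhancer => decoder."""

        def __init__(self, encoder, enhancer, decoder):
            super().__init__()
            self.encoder = encoder
            self.enhancer = enhancer
            self.decoder = decoder
            # log_graph=True only logs the graph if an example input is available.
            self.example_input_array = torch.randn(1, 1, 64, 64)  # placeholder shape

        def forward(self, x):
            return self.decoder(self.enhancer(self.encoder(x)))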