I've added the default TensorBoard logger (from pytorch_lightning.loggers import TensorBoardLogger) to my PyTorch Lightning Trainer, with log_graph=True.
When I train my model, the first view of my graph shows three blocks:
inputs => MyNetworkClassName => Outputs
So far so good.
But then, when I expand MyNetworkClassName it gives me absolutely everything that's going on in my net. That's a lot of arrows going everywhere. I would like to organize this graph into simpler blocks with expandable subgraphs. So in my case, where my network has a typical encoder - enhancer - decoder structure, I would like something more like this:
First graph:
Zooming in on MyNetworkClassName:
Zooming in on encoder:
Zooming in on encoder_layer1:
What are my options here? Should I put everything into separate classes? Are there any commands that allow me to group certain actions together?
Refactoring code into classes also affects the TensorBoard graph (whereas refactoring into methods does not). A typical example of a class that will show up as an expandable block:
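    import torch.nn as nn

    class EncoderLayer(nn.Module):
        """Encoder layer class"""
        def __init__(self, activation_function, kernel_num, kernel_size, idx):
            super().__init__()
            # Assumed: in_channels is kernel_num[idx], and activation_function is an
            # nn.Module instance applied last (both are missing from the snippet as posted).
            self.layer = nn.Sequential(
                nn.Conv2d(kernel_num[idx], kernel_num[idx + 1], kernel_size=(kernel_size, 2)),
                nn.BatchNorm2d(kernel_num[idx + 1]),
                activation_function,
            )

        def forward(self, x):
            return self.layer(x)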
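To get the nested encoder / enhancer / decoder view from the question, the same idea can be applied one level up: wrap each stage in its own nn.Module and assign it to a named attribute. Below is a minimal sketch, not the asker's actual network. The channel counts, the simplified EncoderLayer signature, and the placeholder enhancer and decoder stages are all assumptions made for illustration. The module paths it prints are the hierarchy that the TensorBoard graph scopes should follow.

```python
import torch
import torch.nn as nn


class EncoderLayer(nn.Module):
    """One conv block; as a separate class it gets its own scope."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layer(x)


class Encoder(nn.Module):
    """Groups the encoder layers under a single parent module."""

    def __init__(self, channels):
        super().__init__()
        # Each attribute is a child module, addressable by its attribute name.
        self.encoder_layer1 = EncoderLayer(channels[0], channels[1])
        self.encoder_layer2 = EncoderLayer(channels[1], channels[2])

    def forward(self, x):
        return self.encoder_layer2(self.encoder_layer1(x))


class MyNetworkClassName(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder([1, 8, 16])
        # Hypothetical placeholder stages; swap in the real enhancer/decoder.
        self.enhancer = nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.ReLU())
        self.decoder = nn.Conv2d(16, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.decoder(self.enhancer(self.encoder(x)))


model = MyNetworkClassName()

# Top-level children: the three blocks one level below the network itself.
top_level = [name for name, _ in model.named_children()]

# Full dotted paths of all submodules (the root has an empty name, so skip it).
module_paths = [name for name, _ in model.named_modules() if name]

# Sanity check that the nested modules still compose into a working forward pass.
out = model(torch.randn(2, 1, 16, 16))

print(top_level)
print(module_paths[:4])
```

Code that stays inline in forward (or is only split into helper methods) does not create a new submodule, which is consistent with the observation above that refactoring into methods leaves the graph unchanged.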