
Display PyTorch model with multiple outputs using torchviz make_dot


I have a model with multiple outputs, 4 to be exact:

 def forward(self, x):
     outputs = []
     for conv, act in zip(self.Convolutions, self.Activations):
         y = conv(x)
         outputs.append(act(y))
     return outputs
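For a self-contained version of the model above, here is a minimal sketch. The class name `MultiOutputNet`, the channel counts, and the `nn.Sigmoid` activations are assumptions for illustration; only the `forward` logic comes from the question. Each of the four heads applies its own convolution to the same input, so `forward` returns a plain Python list of four tensors.

```python
import torch
import torch.nn as nn


class MultiOutputNet(nn.Module):  # hypothetical name, not from the question
    def __init__(self, in_channels: int = 3, out_channels: int = 1, heads: int = 4):
        super().__init__()
        # One convolution + activation pair per output head (assumed layer sizes).
        # nn.ModuleList registers the layers so their parameters are tracked.
        self.Convolutions = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
            for _ in range(heads)
        )
        self.Activations = nn.ModuleList(nn.Sigmoid() for _ in range(heads))

    def forward(self, x):
        outputs = []
        # Every head sees the same input x
        for conv, act in zip(self.Convolutions, self.Activations):
            outputs.append(act(conv(x)))
        return outputs  # a Python list, not a Tensor


model = MultiOutputNet()
outs = model(torch.randn(2, 3, 16, 16))
print(type(outs).__name__, len(outs), tuple(outs[0].shape))
```

Using `nn.ModuleList` rather than a plain list matters here: with a plain list, `model.named_parameters()` would not see the convolution weights.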

I wanted to display it using make_dot from torchviz:

 import torch
 from torchviz import make_dot

 # model and device are defined elsewhere; generator is elided in the question
 generator = ...
 batch = next(iter(generator))
 input, output = batch["input"].to(device, dtype=torch.float), batch["output"].to(device, dtype=torch.float)
 dot = make_dot(model(input), params=dict(model.named_parameters()))

But I get the following error:

 File "/opt/local/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torchviz/dot.py", line 37, in make_dot
 output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
 AttributeError: 'list' object has no attribute 'grad_fn'

Obviously a list does not have a grad_fn attribute, but according to this discussion, a model can return a list of outputs.
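The failing line in the traceback only special-cases `tuple`: anything else is treated as a single tensor and asked for its `.grad_fn`. The sketch below copies that expression and runs it against stand-in objects to show why a list fails while a tuple of the same items works. `FakeTensor` and `output_nodes` are hypothetical names used only for illustration, not PyTorch or torchviz APIs, and this mirrors only the single line shown in the traceback from the torchviz version in the question.

```python
class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor: only carries a grad_fn attribute."""

    def __init__(self, name):
        self.grad_fn = f"<{name}Backward>"


def output_nodes(var):
    # Same expression as torchviz/dot.py line 37 in the traceback
    return (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)


outputs = [FakeTensor(f"Head{i}") for i in range(4)]

print(output_nodes(tuple(outputs)))  # works: one grad_fn per output

try:
    output_nodes(outputs)  # a list falls into the "single tensor" branch
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'grad_fn'
```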

What am I doing wrong?


Solution

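  • The model can return a list, but make_dot expects a single Tensor or, as the isinstance(var, tuple) check in the traceback shows, a tuple of Tensors. If the output components have compatible shapes (equal in every dimension except the one you concatenate along), I suggest merging them with torch.cat; otherwise, passing tuple(model(input)) should satisfy the tuple branch.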
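A minimal sketch of both workarounds, assuming four heads whose outputs share the shape `(N, 1, H, W)`; the shapes and the stand-in `heads` built from `nn.Conv2d` are assumptions, not the question's actual model. `torch.cat` along the channel dimension yields one tensor whose `grad_fn` graph reaches all four heads, while `tuple(...)` keeps the outputs separate for the tuple branch. The commented `make_dot` calls show the intended usage and need `torchviz` and Graphviz installed.

```python
import torch
import torch.nn as nn

# Stand-in for the question's model: four conv heads on the same input (assumed sizes)
heads = nn.ModuleList(nn.Conv2d(3, 1, kernel_size=3, padding=1) for _ in range(4))
x = torch.randn(2, 3, 16, 16)
outputs = [torch.sigmoid(conv(x)) for conv in heads]  # a list, like the model's forward

# Option 1: concatenate along the channel dim (all other dims must match)
merged = torch.cat(outputs, dim=1)   # shape (2, 4, 16, 16)
print(merged.shape, merged.grad_fn)  # one grad_fn covering all heads

# Option 2: a tuple, which the isinstance(var, tuple) check accepts
as_tuple = tuple(outputs)

# from torchviz import make_dot
# params = dict(heads.named_parameters())
# make_dot(merged, params=params).render("graph_cat", format="png")
# make_dot(as_tuple, params=params).render("graph_tuple", format="png")
```

Concatenation draws a single extra `Cat` node joining the heads, whereas the tuple version keeps four separate output nodes in the graph.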