
How to decide the 'input_size' parameter of torchsummary.summary(model=model.policy, input_size=(int, int, int))?


This is my CNN, as printed by print(model.policy):

CnnPolicy(
  (actor): Actor(
    (features_extractor): CustomCNN(
      (cnn): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        (1): ReLU()
        (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        (3): ReLU()
        (4): Flatten(start_dim=1, end_dim=-1)
      )
      (linear): Sequential(
        (0): Linear(in_features=6, out_features=128, bias=True)
        (1): ReLU()
      )
    )
    (mu): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=3, bias=True)
      (5): Tanh()
    )
  )

When I try to print the network architecture using torchsummary.summary(model=model.policy, input_size=(1, 32, 32)), I get the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x50176 and 6x128)

I have tried many input_size combinations, but all of them fail.

How should I choose the input_size parameter?


Solution

  • This is not a problem with summary(), but with your network: the first Linear layer does not fit the output of the Flatten() layer in front of it.

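    I recommend assembling your network layer by layer and testing it with a random input such as x = torch.rand(batch_dim, channel_dim, spatial1, spatial2) to see whether the layers fit together. (torch.rand returns float32, matching the layer weights; torch.from_numpy(np.random.rand(...)) would give float64 and fail with a dtype error.)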
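    A minimal sketch of that check, assuming a single-channel 32x32 input as in your summary() call. The layers are copied from the printed cnn and linear blocks; the helper trace_shapes and the batch size of 2 are illustrative choices, not part of any library.

```python
import torch
import torch.nn as nn

# Layers copied from the printed CustomCNN; weights don't matter for shape checks.
layers = [
    nn.Conv2d(1, 32, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=1),
    nn.ReLU(),
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.Linear(6, 128),
    nn.ReLU(),
]


def trace_shapes(layers, x):
    """Feed x through the layers one at a time and record each output shape.

    Returns (shapes, failing_layer, error); failing_layer is None if all layers ran.
    """
    shapes = [tuple(x.shape)]
    for layer in layers:
        try:
            x = layer(x)
        except RuntimeError as err:
            return shapes, layer, err
        shapes.append(tuple(x.shape))
    return shapes, None, None


# torch.rand yields float32, the default dtype of the layer weights.
x = torch.rand(2, 1, 32, 32)  # (batch, channels, height, width)
shapes, failing_layer, error = trace_shapes(layers, x)
for s in shapes:
    print(s)
print("failed at:", failing_layer)
print("error:", error)
```

    The printed shapes go from (2, 1, 32, 32) to (2, 50176) after Flatten, and the loop stops at the Linear layer with in_features=6, so the mismatch is pinned to one layer instead of surfacing inside summary().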

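    Flatten is usually used to flatten the channel and spatial dimensions but not the batch dimension, and that is exactly what Flatten(start_dim=1, end_dim=-1) does: end_dim=-1 includes the last dimension, so channels and both spatial dimensions are merged. The Flatten layer itself is fine.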
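    To see what Flatten(start_dim=1, end_dim=-1) actually produces, here is a quick check on a random tensor shaped like the output of your second convolution (the 28x28 size assumes a 32x32 input):

```python
import torch
import torch.nn as nn

x = torch.rand(2, 64, 28, 28)  # (batch, channels, height, width) after the two convs

# Flatten(start_dim=1, end_dim=-1): merges channels, height and width, keeps batch.
full = nn.Flatten(start_dim=1, end_dim=-1)(x)
print(full.shape)  # torch.Size([2, 50176])

# Only an explicit end_dim=2 would leave the last spatial dimension separate.
partial = nn.Flatten(start_dim=1, end_dim=2)(x)
print(partial.shape)  # torch.Size([2, 1792, 28])
```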

    Also, check that each layer's input channels match the previous layer's output channels. I can debug your network further if you provide a copy-and-paste runnable example, not just the printed structure.

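    The first Linear layer currently expects 6 input features, but the Flatten layer before it returns 64 * H_out * W_out values per sample, where H_out = H - 4 and W_out = W - 4 for two unpadded 3x3 convolutions with stride 1. With input_size=(1, 32, 32) that is 64 * 28 * 28 = 50176, exactly the number in your error message (the 2 is the batch size summary() uses for its dummy input). Since the flattened size is always a multiple of 64, it can never equal 6, so no input_size will work: in_features of that Linear layer has to match the flattened size of your real input.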
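    The same arithmetic can be checked without running the model. This sketch uses the Conv2d output-size formula from the PyTorch documentation; the helper names are hypothetical, and the defaults (64 output channels, two 3x3 convolutions, stride 1, no padding) are read off the printed network.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of one spatial dimension of nn.Conv2d."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(height, width, out_channels=64, n_convs=2, kernel_size=3):
    """Values per sample after the conv stack and Flatten(start_dim=1, end_dim=-1)."""
    for _ in range(n_convs):
        height = conv2d_out(height, kernel_size)
        width = conv2d_out(width, kernel_size)
    return out_channels * height * width


print(flattened_features(32, 32))  # 50176, the "2x50176" in the error
print(flattened_features(5, 5))    # 64, smallest input that survives both convs
```

    Even the smallest input that survives both convolutions (5x5) already yields 64 features, so in_features=6 cannot be reached.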

    You only get the error now because calling summary() is the very first time your network is actually run. It fails not because of summary(), but because of the ill-connected layers. When combining CNNs with fully connected layers, you need to know the input shape beforehand so you can build the network to fit your data, rather than guessing input_size afterwards.
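    One way to build the network from the input shape is to push a dummy sample through the convolutional part once and read off the flattened size. This is a sketch, assuming observations of shape (1, 32, 32); CustomCNN here is a stand-in for your own class, and features_dim=128 is taken from the printed linear block.

```python
import torch
import torch.nn as nn


class CustomCNN(nn.Module):
    """Feature extractor whose Linear size is derived from the input shape.

    input_shape is an assumption: (channels, height, width) of one observation.
    """

    def __init__(self, input_shape=(1, 32, 32), features_dim=128):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),  # start_dim=1, end_dim=-1 by default
        )
        # Run one dummy sample through the conv stack to measure the flattened size.
        with torch.no_grad():
            n_flatten = self.cnn(torch.zeros(1, *input_shape)).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, x):
        return self.linear(self.cnn(x))


model = CustomCNN(input_shape=(1, 32, 32))
out = model(torch.rand(2, 1, 32, 32))
print(model.linear[0])  # Linear(in_features=50176, out_features=128, bias=True)
print(out.shape)        # torch.Size([2, 128])
```

    Once the Linear layer is sized this way, passing the same (channels, height, width) as input_size to summary() should work, because that is the shape the network was built for.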

    Good luck!