Tags: python, pytorch, conv-neural-network, dataloader

What should be the input shape for 3D CNN on a sequence of images?


The Conv3d documentation (https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#conv3d) states that the expected input shape for a 3D convolution is (N, C_in, D, H, W). Imagine I have a sequence of images that I want to pass to a 3D CNN. Am I right that:

  1. N -> number of sequences (mini batch)
  2. Cin -> number of channels (3 for rgb)
  3. D -> Number of images in a sequence
  4. H -> Height of one image in the sequence
  5. W -> Width of one image in the sequence

The reason I am asking is that when I stack image tensors with a = torch.stack([img1, img2, img3, img4, img5]), a has shape torch.Size([5, 3, 396, 247]). Is it compulsory to reshape the tensor to torch.Size([3, 5, 396, 247]) so that the number of channels comes first, or does the order not matter inside the DataLoader?

Note that the DataLoader automatically adds one more dimension, which corresponds to N.
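To illustrate the situation in the question, here is a minimal sketch (dummy frames with the same (3, 396, 247) shape as in the question) showing that torch.stack with its default dim=0 puts the sequence length first rather than the channel dimension:

```python
import torch

# Five dummy RGB frames, each of shape (C, H, W) = (3, 396, 247)
frames = [torch.randn(3, 396, 247) for _ in range(5)]

a = torch.stack(frames)  # default dim=0 stacks along a new leading axis
print(a.shape)           # torch.Size([5, 3, 396, 247]) -- D comes first, not C
```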


Solution

  • Yes, it matters: you need to ensure the dimensions are ordered correctly (assuming you use the DataLoader's default collate function). One way to do this is to call torch.stack with dim=1 instead of the default dim=0. For example,

    a = torch.stack([img1, img2, img3, img4, img5], dim=1)
    

    results in a having the desired shape torch.Size([3, 5, 396, 247]).
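Putting it all together, here is a hedged end-to-end sketch (the ClipDataset class and all sizes are made up for illustration): each dataset item is stacked into a (C, D, H, W) clip with dim=1, the DataLoader's default collate function adds the batch dimension N, and the resulting (N, C, D, H, W) batch passes through nn.Conv3d without any reshaping:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

class ClipDataset(Dataset):
    # Hypothetical dataset: each item is one clip of `depth` RGB frames,
    # stacked along dim=1 so the item shape is (C, D, H, W).
    def __init__(self, n_clips=4, depth=5, h=396, w=247):
        self.clips = [
            torch.stack([torch.randn(3, h, w) for _ in range(depth)], dim=1)
            for _ in range(n_clips)
        ]

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, i):
        return self.clips[i]

# The default collate function stacks items along a new leading N axis.
loader = DataLoader(ClipDataset(), batch_size=2)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([2, 3, 5, 396, 247]) == (N, C, D, H, W)

# The batch feeds straight into Conv3d (in_channels must match C=3).
conv = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=3)
out = conv(batch)
print(out.shape)  # torch.Size([2, 8, 3, 394, 245]) -- kernel_size=3, no padding
```

With kernel_size=3 and no padding, each of D, H, W shrinks by 2, which is a quick sanity check that Conv3d really treated the second axis as channels and the third as depth.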