Tags: deep-learning, pytorch, torch, torchvision, pytorch-lightning

What should be the input to torch.nn.MultiheadAttention if I have an RGB image?


I have a PyTorch tensor that holds a batch of B RGB images, each of size 3×H×W, so the tensor's shape is (B, 3, H, W).

I would like to reshape this tensor so that it can be passed as input to the nn.MultiheadAttention module from the torch library.

In the official documentation for torch.nn.MultiheadAttention, the input and output tensors' shapes are determined according to batch_first:

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
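A minimal sketch of the two layouts described in that quote, using arbitrary toy sizes (batch=2, seq=5, feature=8) chosen only for illustration. In both layouts the last dimension is the feature dimension and must equal embed_dim; only the order of the first two dimensions changes.

```python
import torch
import torch.nn as nn

# Toy sizes, chosen only for illustration
batch, seq, feature = 2, 5, 8

# Default layout (batch_first=False): inputs are (seq, batch, feature)
mha_default = nn.MultiheadAttention(embed_dim=feature, num_heads=2)
x_sbf = torch.randn(seq, batch, feature)
out_default, _ = mha_default(x_sbf, x_sbf, x_sbf)
print(out_default.shape)  # torch.Size([5, 2, 8])

# batch_first=True: inputs are (batch, seq, feature)
mha_bf = nn.MultiheadAttention(embed_dim=feature, num_heads=2, batch_first=True)
x_bsf = torch.randn(batch, seq, feature)
out_bf, weights = mha_bf(x_bsf, x_bsf, x_bsf)
print(out_bf.shape)   # torch.Size([2, 5, 8])
print(weights.shape)  # torch.Size([2, 5, 5]), attention weights averaged over heads by default
```

Here "seq" is the number of tokens that attend to each other, and "feature" is the length of the vector describing each token.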

What exactly do seq and feature mean here, and how can I get them from my image?

(This will also help me choose the nn.MultiheadAttention parameters embed_dim and num_heads.)
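A short sketch of how those two parameters relate to the input shape, using embed_dim=256 * 4 and num_heads=4 as an example. embed_dim must be divisible by num_heads, each head works on embed_dim // num_heads features, and for self-attention (the same tensor as query, key and value) the input's last dimension must equal embed_dim. The sequence length, here an arbitrary 10, is not constrained by either parameter.

```python
import torch
import torch.nn as nn

# Example configuration: 4 heads over a 1024-dim embedding
attention = nn.MultiheadAttention(embed_dim=256 * 4, num_heads=4)

# embed_dim must be divisible by num_heads; each head gets embed_dim // num_heads features
print(attention.embed_dim)  # 1024
print(attention.head_dim)   # 256

# The input's feature (last) dimension must equal embed_dim; the sequence length is free.
# batch_first defaults to False, so the layout is (seq, batch, feature).
x = torch.randn(10, 2, 1024)
out, _ = attention(x, x, x)
print(out.shape)  # torch.Size([10, 2, 1024])
```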

This is my current initialization:

self.attention = torch.nn.MultiheadAttention(embed_dim=256 * 4, num_heads=4)

And in my forward function:

x, _ = self.attention(x, x, x)  # forward returns (attn_output, attn_output_weights)

What shape should I reshape x to?
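Whatever shape x ends up with, note that the forward call of nn.MultiheadAttention returns a tuple (attn_output, attn_output_weights), so assigning the result directly to x stores the tuple rather than a tensor. A minimal sketch of the unpacking, wrapped in a hypothetical SelfAttentionBlock module; batch_first=True is an assumption made here for readability.

```python
import torch
import torch.nn as nn


class SelfAttentionBlock(nn.Module):
    """Hypothetical wrapper around nn.MultiheadAttention for self-attention."""

    def __init__(self, embed_dim: int = 256 * 4, num_heads: int = 4):
        super().__init__()
        # batch_first=True is an assumption: inputs are (batch, seq, feature)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The attention call returns (attn_output, attn_output_weights); keep only the output
        x, _ = self.attention(x, x, x)
        return x


block = SelfAttentionBlock()
x = torch.randn(2, 16, 1024)  # (batch, seq, embed_dim)

raw = block.attention(x, x, x)
print(type(raw).__name__, len(raw))  # tuple 2

y = block(x)
print(y.shape)  # torch.Size([2, 16, 1024])
```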


Solution

  • For each image, extract image patches and flatten them. Your batch size is the number of images, your sequence length is the number of patches per image, and your feature size is the length of a flattened patch.
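A minimal sketch of the patch approach, assuming non-overlapping square patches of size P=16 on 64×64 images (both values are arbitrary), batch_first=True, and torch.nn.functional.unfold for the patch extraction; image_to_patch_sequence is a hypothetical helper name. A flattened RGB patch has 3 * P * P features, which can never equal an embed_dim of 256 * 4 = 1024 (not a multiple of 3). Option B therefore adds a learned nn.Linear projection to 1024 features. That is a ViT-style patch embedding, a technique beyond the answer itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def image_to_patch_sequence(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, num_patches, C * patch_size * patch_size).

    Assumes H and W are multiples of patch_size (otherwise border pixels are dropped).
    """
    # Stride == kernel size gives non-overlapping patches: (B, C * P * P, num_patches)
    patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)
    # Swap to (batch, seq, feature) for batch_first=True
    return patches.transpose(1, 2)


B, H, W, P = 2, 64, 64, 16
images = torch.randn(B, 3, H, W)

tokens = image_to_patch_sequence(images, P)
print(tokens.shape)  # torch.Size([2, 16, 768]): (64 / 16) ** 2 patches, 3 * 16 * 16 features

# Option A: use the flattened patch length directly as embed_dim (768, divisible by 4 heads)
mha = nn.MultiheadAttention(embed_dim=3 * P * P, num_heads=4, batch_first=True)
out, _ = mha(tokens, tokens, tokens)
print(out.shape)  # torch.Size([2, 16, 768])

# Option B (assumption, ViT-style): project each patch to embed_dim=1024 first
proj = nn.Linear(3 * P * P, 256 * 4)
mha_1024 = nn.MultiheadAttention(embed_dim=256 * 4, num_heads=4, batch_first=True)
emb = proj(tokens)
out_1024, _ = mha_1024(emb, emb, emb)
print(out_1024.shape)  # torch.Size([2, 16, 1024])

# For non-overlapping patches, F.fold exactly inverts F.unfold back to (B, 3, H, W)
restored = F.fold(tokens.transpose(1, 2), output_size=(H, W), kernel_size=P, stride=P)
print(torch.equal(restored, images))  # True
```

Option A follows the answer literally (feature size equals the flattened patch length). Option B keeps an arbitrary embed_dim such as 1024, at the cost of an extra learned layer.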