I have a tensor of embeddings that I want to reduce to a smaller number of embeddings, and I am working in a batched environment. The tensor shape is (B, F, D), where B is the batch size, F is the number of embeddings, and D is the embedding dimension. I want to learn a reduction to (B, F-n, D).
e.g.
import torch
B = 10
F = 20
F_desired = 17
D = 64
x = torch.randn(B, F, D)
# torch.Size([10, 20, 64])
reduction = torch.?
y = reduction(x)
print(y.shape)
# torch.Size([10, 17, 64])
I think a 1x1 convolution would make sense here, but I'm not sure how to confirm it's actually doing what I expect. I'd love to hear whether this is the right approach, or if there are better ones:
reduction = torch.nn.Conv1d(
    in_channels=F,
    out_channels=F_desired,
    kernel_size=1,
)
A 1d conv with a kernel size of 1 accomplishes this:
import torch
import torch.nn as nn

B = 10
F = 20
F_desired = 17
D = 64
x = torch.randn(B, F, D)
# Conv1d treats dim 1 (F) as channels and dim 2 (D) as length,
# so a size-1 kernel learns a mix of the F embeddings at every position
reduction1 = nn.Conv1d(F, F_desired, kernel_size=1)
x1 = reduction1(x)
print(x1.shape)
> torch.Size([10, 17, 64])
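To confirm the conv is doing what you'd expect: a kernel-size-1 Conv1d takes, at every position along D, a learned linear combination of the F channels. Continuing from the snippet above, here is a minimal sketch that reproduces its output by hand with einsum (x1_manual is just an illustrative name):
w = reduction1.weight.squeeze(-1)   # conv weight (F_desired, F, 1) -> (F_desired, F)
b = reduction1.bias                 # (F_desired,)
# out[b, g, d] = sum_f w[g, f] * x[b, f, d] + b[g]
x1_manual = torch.einsum('gf,bfd->bgd', w, x) + b[None, :, None]
print(torch.allclose(x1, x1_manual, atol=1e-6))
> True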
You could also use a linear layer, provided you permute the axes:
reduction2 = nn.Linear(F, F_desired)
# move F to the last axis, project it to F_desired, then move it back
x2 = reduction2(x.permute(0, 2, 1)).permute(0, 2, 1)
print(x2.shape)
> torch.Size([10, 17, 64])
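If you prefer the linear version, you can hide the permutes inside a small module; a minimal sketch, with EmbeddingReducer as an illustrative name:
import torch
import torch.nn as nn

class EmbeddingReducer(nn.Module):
    """Learns a reduction from (B, F, D) to (B, F_out, D)."""
    def __init__(self, f_in, f_out):
        super().__init__()
        self.proj = nn.Linear(f_in, f_out)

    def forward(self, x):
        # (B, F, D) -> (B, D, F) -> (B, D, F_out) -> (B, F_out, D)
        return self.proj(x.permute(0, 2, 1)).permute(0, 2, 1)

reducer = EmbeddingReducer(20, 17)
print(reducer(torch.randn(10, 20, 64)).shape)
> torch.Size([10, 17, 64])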
Note that if your convolution kernel is size 1, these are actually equivalent operations: a kernel-size-1 conv is just a pointwise linear map over the channel dimension, so copying the weights from one layer into the other reproduces the same output:
# conv weight is (F_desired, F, 1); squeezing the kernel dim gives
# exactly the (F_desired, F) matrix a Linear layer expects
reduction2.weight.data = reduction1.weight.squeeze().data
reduction2.bias.data = reduction1.bias.data
x2 = reduction2(x.permute(0, 2, 1)).permute(0, 2, 1)
print(torch.allclose(x1, x2, atol=1e-6))
> True
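As one more sanity check, both layers carry the same parameters (an F_desired x F weight plus an F_desired bias), so their parameter counts match:
n_conv = sum(p.numel() for p in reduction1.parameters())
n_linear = sum(p.numel() for p in reduction2.parameters())
print(n_conv, n_linear)
> 357 357
In practice the two are interchangeable here; the linear form arguably makes the "mix the F embeddings" intent more explicit, at the cost of the two permutes.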