I need to change the shape of a tensor from [2, 48, 196] to [2, 48, 14, 14]. I read that there is an "unflatten" function in PyTorch, but I couldn't understand how to use it. Is there an example?
Here is an example for your question.
```python
import torch

x = torch.randn(2, 48, 196)                  # tensor to reshape
unflatten = torch.nn.Unflatten(2, (14, 14))  # split dim 2 (size 196) into 14 x 14
output = unflatten(x)
```
If you check `output.shape`, you get `torch.Size([2, 48, 14, 14])`.
`Unflatten` expands one specified dimension into a desired shape. In your case, you want to expand dimension 2, which has size 196, into the new shape `(14, 14)`. This works because 14 × 14 = 196.
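To make the "expand dim 2" idea concrete, here is a minimal sketch (variable names are illustrative). It checks that `nn.Unflatten(2, (14, 14))` gives the same result as `Tensor.unflatten` and as a plain `view`/`reshape` to the target shape, and that flat index `k` in dim 2 lands at row `k // 14`, column `k % 14` (row-major order).

```python
import torch
import torch.nn as nn

x = torch.randn(2, 48, 196)

# Module form, as in the answer above
out_module = nn.Unflatten(2, (14, 14))(x)

# Equivalent forms: only dim 2 is split, dims 0 and 1 are untouched
out_method = x.unflatten(2, (14, 14))
out_view = x.view(2, 48, 14, 14)
out_reshape = x.reshape(x.shape[0], x.shape[1], 14, 14)

print(out_module.shape)                      # torch.Size([2, 48, 14, 14])
print(torch.equal(out_module, out_method))   # True
print(torch.equal(out_module, out_view))     # True
print(torch.equal(out_module, out_reshape))  # True

# Row-major split: flat index k in dim 2 maps to (k // 14, k % 14)
k = 100
print(torch.equal(out_module[:, :, k // 14, k % 14], x[:, :, k]))  # True
```

Since only the shape metadata changes, you can use whichever form reads best in your code; the module form is convenient inside `nn.Sequential`.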
`torch.nn.Unflatten` takes two parameters: `dim`, the dimension to expand, and `unflattened_size`, the new shape for that dimension. Therefore, your layer should look like `unflatten = torch.nn.Unflatten(2, (14, 14))`.
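A short sketch of those two parameters in use, assuming a reasonably recent PyTorch. It shows the same layer with keyword arguments, with a negative `dim` (counted from the end), and what happens when `unflattened_size` does not multiply to 196. The variable names here are only for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 48, 196)

# Same layer, with the parameter names spelled out
unflatten = nn.Unflatten(dim=2, unflattened_size=(14, 14))
print(unflatten(x).shape)  # torch.Size([2, 48, 14, 14])

# dim=-1 refers to the last dimension, which is dim 2 for this 3-D input
last_dim_out = nn.Unflatten(-1, (14, 14))(x)
print(last_dim_out.shape)  # torch.Size([2, 48, 14, 14])

# The product of unflattened_size must equal the size of dim (14 * 14 = 196).
# A mismatch such as (14, 15) is only detected when the module is called.
mismatch_caught = False
try:
    nn.Unflatten(2, (14, 15))(x)
except RuntimeError as err:
    mismatch_caught = True
    print("RuntimeError:", err)
```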