Does anyone here know whether the torch.squeeze function respects the batch (i.e. first) dimension? From some inline code it seems it does not, but maybe someone else knows the inner workings better than I do.
Btw, the underlying problem is that I have a tensor of shape (n_batch, channel, x, y, 1). I want to remove the last dimension with a simple function, so that I end up with a shape of (n_batch, channel, x, y).
A reshape is of course possible, or even selecting the last axis. But I want to embed this functionality in a layer so that I can easily add it to a ModuleList or Sequential object.
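One way to get such a layer is a tiny custom module. Below is a minimal sketch: `SqueezeDim` is a hypothetical name of my own, not a class that ships with `torch.nn`. It only removes the single dimension you name, so it can go into `nn.Sequential` or `nn.ModuleList` like any other layer:

```python
import torch
from torch import nn


class SqueezeDim(nn.Module):
    """Hypothetical helper layer: removes exactly one size-1 dimension."""

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x.squeeze(dim) silently returns x unchanged when that dim is not
        # size 1, so check explicitly to surface shape mistakes early.
        if x.size(self.dim) != 1:
            raise ValueError(
                f"expected size 1 at dim {self.dim}, got shape {tuple(x.shape)}"
            )
        return x.squeeze(self.dim)

    def extra_repr(self) -> str:
        return f"dim={self.dim}"


# Usage: drop the trailing singleton axis of a (n_batch, channel, x, y, 1) tensor.
model = nn.Sequential(SqueezeDim(dim=-1))
x = torch.randn(1, 3, 8, 8, 1)  # batch size 1 on purpose; made-up sizes
print(model(x).shape)  # torch.Size([1, 3, 8, 8])
```

The explicit size check is a design choice, not something `squeeze` does for you. Drop it if you would rather have a silent no-op.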
EDIT: I just found out that in TensorFlow (2.5.0) the function tf.linalg.diag DOES respect the batch dimension. Just an FYI that this may differ per function you are using.
No! squeeze doesn't respect the batch dimension: called without a dimension argument, it removes every dimension of size 1, including the batch dimension. That makes it a potential source of errors whenever the batch size may be 1. A rule of thumb is that only the classes and functions in torch.nn respect the batch dimension by default.
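A quick way to see this, using the question's (n_batch, channel, x, y, 1) layout with made-up sizes: the shape after a bare `squeeze()` depends on whether the batch happens to be 1.

```python
import torch

# Made-up sizes for the question's (n_batch, channel, x, y, 1) layout.
for n_batch in (4, 1):
    x = torch.randn(n_batch, 3, 8, 8, 1)
    # squeeze() with no dim drops *every* size-1 axis, batch included.
    print(n_batch, tuple(x.squeeze().shape))
# prints:
# 4 (4, 3, 8, 8)
# 1 (3, 8, 8)
```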
This has caused me headaches in the past. I recommend using reshape, or only using squeeze with the optional dim argument. In your case you could use .squeeze(4) to remove only the last dimension. That way nothing unexpected happens. Squeeze without the dim argument has led me to unexpected results, specifically when nn.DataParallel is being used (in which case the batch size for a particular replica may be reduced to 1).
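Here is a rough sketch of the DataParallel situation. It does not run `nn.DataParallel`. As an approximation, it uses `torch.chunk` along dim 0 to mimic an uneven split of a batch of 3 over 2 devices. It then compares a bare `squeeze()` with `squeeze(4)`, an equivalent `reshape`, and plain indexing of the last axis:

```python
import torch

x = torch.randn(3, 5, 4, 4, 1)  # made-up sizes; batch of 3

# Approximate an uneven split across 2 devices: chunks of size 2 and 1.
# (torch.chunk only stands in for nn.DataParallel's scatter here.)
for part in torch.chunk(x, 2, dim=0):
    a = part.squeeze(4)                # removes only the last axis
    b = part.reshape(part.shape[:-1])  # same result via reshape
    c = part[..., 0]                   # "selecting the last axis"
    assert torch.equal(a, b) and torch.equal(a, c)
    print(tuple(part.squeeze().shape), tuple(a.shape))
# prints:
# (2, 5, 4, 4) (2, 5, 4, 4)
# (5, 4, 4) (1, 5, 4, 4)
```

All three explicit variants keep the batch axis on the size-1 chunk. The bare `squeeze()` does not.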