
Is unsqueeze(-1) equal to squeeze(1) in PyTorch?


I don't understand what unsqueeze(-1) means. I ran into this piece of code from HuggingFace:

import torch


def encode_batch(model, tokenizer, sentences, device):
    input_ids = tokenizer(sentences, padding=True, max_length=512, truncation=True, return_tensors="pt",
                          add_special_tokens=True).to(device)
    features = model(**input_ids)[0]
    features = torch.sum(features[:, 1:, :] * input_ids["attention_mask"][:, 1:].unsqueeze(-1), dim=1) / torch.clamp(
        torch.sum(input_ids["attention_mask"][:, 1:], dim=1, keepdims=True), min=1e-9)
    return features

and I'm trying to understand what this calculation means.

I read that unsqueeze turns an n-dimensional tensor into an (n+1)-dimensional one by adding an extra dimension of size 1, and that squeeze turns an n-dimensional tensor into an (n-1)-dimensional one by removing a dimension of size 1.

So what does unsqueeze(-1) mean?

By this logic, adding -1 dimensions would lower the tensor by one dimension, so does that mean unsqueeze(-1) == squeeze(1)?

I saw these examples (first, second) and of course went over the documentation.

Thank you!


Solution

  • The argument to unsqueeze and squeeze is not the number of dimensions to add or remove; it is the position at which one dimension is added or removed. An argument of -1 just means the last position, so squeeze(-1) removes the last dimension (provided it has size 1; otherwise the tensor is returned unchanged) and unsqueeze(-1) adds a new dimension after the current last one. So unsqueeze(-1) is not the same as squeeze(1).

    Some examples:

    import torch

    a = torch.randn((4, 4, 1))
    a.shape               # torch.Size([4, 4, 1])
    a.squeeze(2).shape    # torch.Size([4, 4]), dimension 2 has been removed
    a.squeeze(-1).shape   # torch.Size([4, 4]), last dimension has been removed (same effect)
    a.unsqueeze(0).shape  # torch.Size([1, 4, 4, 1]), one new dimension as first
    a.unsqueeze(1).shape  # torch.Size([4, 1, 4, 1]), one new dimension as second
    a.unsqueeze(3).shape  # torch.Size([4, 4, 1, 1]), one new dimension as fourth
    a.unsqueeze(-1).shape # torch.Size([4, 4, 1, 1]), one new dimension as last (same effect)
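
To connect this back to the code in the question, here is a minimal sketch of what the unsqueeze(-1) appears to be doing there: masked mean pooling over tokens. The tensor sizes and the random `features` tensor are made up for illustration and stand in for the real model output. The names `tokens`, `mask3d`, `summed`, `counts` and `pooled` are hypothetical and not from the original snippet.

```python
import torch

torch.manual_seed(0)

# Made-up sizes for illustration: 2 sentences, 5 tokens each, hidden size 3.
batch, seq_len, hidden = 2, 5, 3
features = torch.randn(batch, seq_len, hidden)     # stand-in for model(**input_ids)[0]
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])   # 1 = real token, 0 = padding

# Drop the first token, as the [:, 1:] slices in the question do.
tokens = features[:, 1:, :]                        # shape (2, 4, 3)
# Cast to float so the arithmetic below stays in floating point.
mask = attention_mask[:, 1:].float()               # shape (2, 4)

# unsqueeze(-1) appends a size-1 dimension: (2, 4) -> (2, 4, 1).
# That lets the mask broadcast across the hidden dimension.
mask3d = mask.unsqueeze(-1)
masked = tokens * mask3d                           # shape (2, 4, 3); padded positions are zeroed

# Sum over the token dimension, then divide by the number of real tokens.
summed = masked.sum(dim=1)                                      # shape (2, 3)
counts = torch.clamp(mask.sum(dim=1, keepdim=True), min=1e-9)   # shape (2, 1)
pooled = summed / counts                                        # shape (2, 3)

# squeeze(1) does something different: dimension 1 has size 4, not 1,
# so the tensor comes back with the same shape.
unchanged = mask.squeeze(1)                        # shape (2, 4)

print(mask3d.shape, pooled.shape, unchanged.shape)
```

Without the unsqueeze(-1), the (2, 4) mask could not be multiplied with the (2, 4, 3) features, because broadcasting aligns shapes from the last dimension. The clamp only guards against dividing by zero when a sentence has no real tokens left after the slice.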