Tags: python, deep-learning, pytorch, transformer-model, encoder-decoder

with torch.no_grad() Changes the Output Sequence Length in Evaluation Mode


I built a TransformerEncoder-based model, and its output sequence length changes when I run it under "with torch.no_grad()" in evaluation mode.

My model details:

import math
from typing import Optional

import torch
import torch.nn as nn

# PositionalEncoding is the asker's own module; it is not shown in the question.

class TransEnc(nn.Module):
    def __init__(self, ntoken: int, encoder_embedding_dim: int, max_item_count: int, encoder_num_heads: int,
                 encoder_hidden_dim: int, encoder_num_layers: int, padding_idx: int, dropout: float = 0.2):
        super().__init__()
        self.encoder_embedding = nn.Embedding(ntoken, encoder_embedding_dim, padding_idx=padding_idx)
        self.pos_encoder = PositionalEncoding(encoder_embedding_dim, max_item_count, dropout)
        # encoder_hidden_dim is passed as dim_feedforward of each encoder layer
        encoder_layers = nn.TransformerEncoderLayer(encoder_embedding_dim, encoder_num_heads, encoder_hidden_dim, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, encoder_num_layers)
        self.encoder_embedding_dim = encoder_embedding_dim

    def forward(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # (batch, seq) token ids -> (batch, seq, encoder_embedding_dim), scaled by sqrt(d_model)
        src = self.encoder_embedding(src.long()) * math.sqrt(self.encoder_embedding_dim)
        src = self.pos_encoder(src)
        # src_key_padding_mask: True marks padding positions the encoder should ignore
        return self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)

with these hyperparameters:

batch_size = 32
ntoken = 4096
encoder_embedding_dim = 256
max_item_count = 64 # max sequence length with padding
encoder_num_heads = 8
encoder_hidden_dim = 256
encoder_num_layers = 4
padding_idx = 0

I have a tensor src of shape (32, 64), i.e. (batch_size, max_item_count), containing 32 word-level tokenized sentences, each padded to 64 tokens with a different amount of padding.

In training mode ("model.train()"), I set "src_key_padding_mask = src == tokenizer.pad_token_id" and run "logits = model(src=src, src_key_padding_mask=src_key_padding_mask)". The logits have the expected shape (32, 64, 256), i.e. (batch_size, max_item_count, encoder_embedding_dim).

However, in evaluation mode ("model.eval()"), with the same mask, I run "with torch.no_grad(): logits = model(src=src, src_key_padding_mask=src_key_padding_mask)" and get a different shape each time, e.g. (32, 31, 256) or (32, 25, 256). I want the logits to keep the shape (32, 64, 256). How can I fix this?
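One likely explanation, stated here as an assumption rather than something confirmed in the question: in eval mode with gradients disabled and a src_key_padding_mask supplied, nn.TransformerEncoder in these PyTorch versions can take a "fast path" that converts the padded batch into a nested tensor, drops the masked positions, and pads the result back only to the longest unmasked sentence in the batch. If that is what happens, the returned length would depend on each batch's longest sentence, which would match shapes such as (32, 31, 256). The hypothetical helper below (pure Python, assuming right-padded rows and a pad id of 0) computes the length that would then be returned:

```python
from typing import List

PAD_ID = 0  # assumption: tokenizer.pad_token_id == 0, matching padding_idx in the question


def fastpath_output_length(batch: List[List[int]], pad_id: int = PAD_ID) -> int:
    """Hypothetical helper: predicts the sequence length returned by the
    nested-tensor fast path, assuming every row is right-padded."""
    # Each row keeps only its non-padding tokens; the output is re-padded
    # to the longest of those rows, not to the original max_item_count.
    return max(sum(1 for tok in row if tok != pad_id) for row in batch)


batch = [
    [5, 8, 2, 0, 0, 0, 0, 0],
    [7, 1, 9, 4, 3, 0, 0, 0],
    [6, 6, 0, 0, 0, 0, 0, 0],
]
print(len(batch[0]), fastpath_output_length(batch))  # 8 5
```

Here the input length is 8, but the predicted output length is 5, the length of the longest unpadded row.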

OS: Windows 10 x64

Python: 3.10.12

Torch: Tried both 1.13.1+cu117 and 2.0.1+cu117; the problem occurs with both.


Solution

  • Downgrading to torch 1.12.1 worked! If you run into this problem, try:

    pip uninstall torch torchvision torchaudio
    pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
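An alternative that may avoid downgrading, assuming the nested-tensor fast path is the cause: nn.TransformerEncoder accepts an enable_nested_tensor argument in 1.13 and 2.0, and setting it to False should keep the output padded to the full input length. The sketch below uses small made-up sizes and random inputs instead of the question's model; in TransEnc the change would be nn.TransformerEncoder(encoder_layers, encoder_num_layers, enable_nested_tensor=False).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small made-up sizes, not the question's hyperparameters
batch_size, seq_len, d_model, nhead = 3, 10, 16, 4

# Random inputs standing in for the output of embedding + positional encoding
x = torch.randn(batch_size, seq_len, d_model)

# Every row is padded after position 6, so the last 4 columns are padding in all rows
padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
padding_mask[:, 6:] = True  # True marks positions the encoder should ignore

layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=32, batch_first=True)
# enable_nested_tensor=False: do not convert the padded batch to a nested tensor
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)

encoder.eval()
with torch.no_grad():
    out = encoder(x, src_key_padding_mask=padding_mask)

print(out.shape)  # torch.Size([3, 10, 16])
```

If changing the constructor is not an option, padding the returned tensor back to src.size(1) with torch.nn.functional.pad would be another workaround; the values at padded positions should be ignored downstream in either case.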