I built a model around `nn.TransformerEncoder`, and the output's sequence length changes when I run it in evaluation mode inside `with torch.no_grad():`.
My model:

```python
import math

import torch
from torch import nn

# PositionalEncoding is defined elsewhere (not shown in this post).


class TransEnc(nn.Module):
    def __init__(self, ntoken: int, encoder_embedding_dim: int, max_item_count: int,
                 encoder_num_heads: int, encoder_hidden_dim: int, encoder_num_layers: int,
                 padding_idx: int, dropout: float = 0.2):
        super().__init__()
        self.encoder_embedding = nn.Embedding(ntoken, encoder_embedding_dim, padding_idx=padding_idx)
        self.pos_encoder = PositionalEncoding(encoder_embedding_dim, max_item_count, dropout)
        encoder_layers = nn.TransformerEncoderLayer(encoder_embedding_dim, encoder_num_heads,
                                                    encoder_hidden_dim, dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, encoder_num_layers)
        self.encoder_embedding_dim = encoder_embedding_dim

    def forward(self, src: torch.Tensor, src_key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # Scale token embeddings by sqrt(d_model), then add positional encodings.
        src = self.encoder_embedding(src.long()) * math.sqrt(self.encoder_embedding_dim)
        src = self.pos_encoder(src)
        src = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
        return src
```
with these hyperparameters:

```python
batch_size = 32
ntoken = 4096
encoder_embedding_dim = 256
max_item_count = 64  # max sequence length with padding
encoder_num_heads = 8
encoder_hidden_dim = 256
encoder_num_layers = 4
padding_idx = 0
```
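The post does not show `PositionalEncoding`. For anyone trying to reproduce this, here is a minimal stand-in, assuming it is the sinusoidal encoding from the PyTorch transformer tutorial. This is a hypothetical version, adapted so the sequence dimension is dim 1, because the encoder layers use `batch_first=True`.

```python
import math

import torch
from torch import nn


class PositionalEncoding(nn.Module):
    """Hypothetical sinusoidal positional encoding for batch_first input.

    Assumed stand-in for the class used (but not shown) in the post; modeled on
    the PyTorch transformer tutorial, with the sequence on dim 1.
    """

    def __init__(self, d_model: int, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)  # leading batch dim broadcasts over the batch
        pe[0, :, 0::2] = torch.sin(position * div_term)  # even channels: sine
        pe[0, :, 1::2] = torch.cos(position * div_term)  # odd channels: cosine
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model); add the encoding for the first seq_len positions
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


pos = PositionalEncoding(d_model=256, max_len=64, dropout=0.2)
x = torch.zeros(32, 64, 256)
print(pos(x).shape)  # torch.Size([32, 64, 256])
```

Whatever the real implementation is, it should preserve the (batch_size, seq_len, d_model) shape, so it should not be the source of the shrinking sequence length.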
My input tensor `src` holds 32 word-level tokenized sentences, each padded to a different extent, and has shape (32, 64), i.e. (batch_size, max_item_count).
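A batch like this, together with the boolean key padding mask built from it, can be sketched as follows. `PAD_ID` is a hypothetical stand-in for `tokenizer.pad_token_id` and is assumed to be 0 (matching `padding_idx = 0`); the sentence lengths are random, and every sentence is assumed to be right-padded to 64 tokens.

```python
import torch

PAD_ID = 0  # hypothetical: assumed value of tokenizer.pad_token_id
batch_size, max_item_count, ntoken = 32, 64, 4096

torch.manual_seed(0)
# Random true length for each sentence.
lengths = torch.randint(low=5, high=40, size=(batch_size,))
# Real tokens are drawn from [1, ntoken) so none collides with PAD_ID.
src = torch.randint(low=1, high=ntoken, size=(batch_size, max_item_count))
positions = torch.arange(max_item_count).unsqueeze(0)  # (1, 64)
src[positions >= lengths.unsqueeze(1)] = PAD_ID  # right-pad past each sentence's length

# Key padding mask: True marks a padding position that attention should ignore.
src_key_padding_mask = src == PAD_ID

print(src.shape, src_key_padding_mask.shape)
```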
In training mode (`model.train()`), I build the mask with `src_key_padding_mask = src == tokenizer.pad_token_id` and run `logits = model(src=src, src_key_padding_mask=src_key_padding_mask)`. The logits have the expected shape (32, 64, 256), i.e. (batch_size, max_item_count, encoder_embedding_dim).
However, in evaluation mode (`model.eval()`), with the same mask and the same call wrapped in `with torch.no_grad():`, I get a different sequence length every time, e.g. (32, 31, 256), (32, 25, 256), and so on. I want the logits to have shape (32, 64, 256). How can I fix this?
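My understanding (an assumption, not something the post confirms) is that in eval mode under `no_grad`, `nn.TransformerEncoder` can take an inference fast path that packs the padded batch into a nested tensor using `src_key_padding_mask`, then converts it back padded only to the longest unmasked sentence in the batch. That would explain lengths like 31 or 25 that vary per batch. The sketch below disables that path with the `enable_nested_tensor=False` constructor argument (present in the 1.13 and 2.0 releases mentioned below). It uses a bare encoder with the post's layer settings and random embeddings instead of `TransEnc`. `pad_to_length` is a hypothetical fallback that zero-pads the output back to the input length, assuming right-padding and downstream code that ignores padded positions.

```python
import torch
import torch.nn.functional as F
from torch import nn

batch_size, max_item_count, d_model = 32, 64, 256

# Same layer settings as the post: 8 heads, hidden dim 256, dropout 0.2, 4 layers.
layer = nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=256,
                                   dropout=0.2, batch_first=True)
# enable_nested_tensor=False: do not convert the padded batch to a nested tensor
# on the inference fast path, so the output keeps the input's sequence length.
encoder = nn.TransformerEncoder(layer, num_layers=4, enable_nested_tensor=False)

torch.manual_seed(0)
src = torch.randn(batch_size, max_item_count, d_model)  # stands in for embedded tokens
lengths = torch.randint(5, 40, (batch_size,))
# True marks right-padding, as with src == tokenizer.pad_token_id.
src_key_padding_mask = torch.arange(max_item_count).unsqueeze(0) >= lengths.unsqueeze(1)

encoder.eval()
with torch.no_grad():
    out = encoder(src, src_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([32, 64, 256])


def pad_to_length(out: torch.Tensor, length: int) -> torch.Tensor:
    """Hypothetical fallback: zero-pad dim 1 of a (batch, seq, dim) tensor up to `length`."""
    missing = length - out.size(1)
    if missing > 0:
        # F.pad pads from the last dim backwards: (dim2_left, dim2_right, dim1_left, dim1_right)
        out = F.pad(out, (0, 0, 0, missing))
    return out
```

In `TransEnc`, the first option would mean passing `enable_nested_tensor=False` where `self.transformer_encoder` is built. With the fallback, padded positions come back as zeros rather than the values the non-fast path computes there, which only matters if something downstream reads those positions.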
- OS: Windows 10 x64
- Python: 3.10.12
- Torch: tried both 1.13.1+cu117 and 2.0.1+cu117; the problem occurs with both.
Update: downgrading to PyTorch 1.12.1 worked. If you run into the same problem, try:

```shell
pip uninstall torch torchvision torchaudio
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116
```