Tags: pytorch, nlp, huggingface-transformers

Determining contents of decoder_hidden_states from T5ForConditionalGeneration


I'm using the Huggingface T5ForConditionalGeneration model without modification.

I want to compute mean pooling over the last hidden state of the T5 decoder, but I can't determine which part of the decoder_hidden_states contains what I'm looking for.

I want to do something like this:

# Prepare batch data
sources = batch_df['Source'].tolist()
tokenized_input = self.tokenizer(sources, return_tensors='pt', padding=True, truncation=True, max_length=self.max_length).to('cuda')
# The BatchEncoding was already moved to the GPU, so its tensors are too
input_ids = tokenized_input['input_ids']
attention_mask = tokenized_input['attention_mask']

input_batch = {
    'input_ids': input_ids, 
    'attention_mask': attention_mask,
    'do_sample': False,
    'num_beams': 1,
    'eos_token_id': self.tokenizer.eos_token_id,
    'pad_token_id': self.tokenizer.pad_token_id,
    'max_length': self.max_output_length,
    'output_scores': True,
    'return_dict_in_generate': True,
    'output_hidden_states': True,
}
outputs = self.model.generate(**input_batch)

# Retrieve the decoder hidden states
decoder_last_hidden_state = outputs.decoder_hidden_states[-1]  # Last layer's hidden states

# Compute the mean of the hidden states across the sequence length dimension
mean_pooled_output = torch.mean(decoder_last_hidden_state, dim=1, keepdim=False)

This approach works for the encoder, but for the decoder, decoder_hidden_states[-1] is a tuple of tensors, not a tensor.
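For reference, this is roughly what works on the encoder side (a sketch, assuming the same generate() call as above, where encoder_hidden_states is a tuple with one tensor per layer):

# Encoder side for comparison: each entry of encoder_hidden_states is a
# tensor of shape [batch_size, source_seq_len, hidden_size], so [-1] is
# already the last layer and can be pooled directly.
encoder_last_hidden_state = outputs.encoder_hidden_states[-1]
mean_pooled_encoder = torch.mean(encoder_last_hidden_state, dim=1)  # [batch_size, hidden_size]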

When I first inspected the tuples, there were 10 tuples, and each tuple contained 7 tensors.

When I inspected the dimensions, like this:

for tuple_number in range(len(outputs.decoder_hidden_states)):  # Checking the tuples
    print(f"Tuple {tuple_number}:")
    for i, tensor in enumerate(outputs.decoder_hidden_states[tuple_number]):
        print(f"  Tensor {i} in Tuple {tuple_number}: shape {tensor.shape}")

the outputs were all like this:

Tuple 0:
  Tensor 0 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 1 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 2 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 3 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 4 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 5 in Tuple 0: shape torch.Size([2, 1, 512])
  Tensor 6 in Tuple 0: shape torch.Size([2, 1, 512])
Tuple 1:
  Tensor 0 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 1 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 2 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 3 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 4 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 5 in Tuple 1: shape torch.Size([2, 1, 512])
  Tensor 6 in Tuple 1: shape torch.Size([2, 1, 512])
. . .

512 matches the max_length of my tokenizer, and 2 is my batch size. (I verified that 2 is the batch size because that number changed when I modified my batch size.)

Then, when I trimmed my input strings to 10 characters each, to my surprise, the number of tuples went from 10 to 39. When I trimmed the strings further, to only 2 characters each, the number of tuples didn't increase beyond 39. When I doubled the input string length instead, the number of tuples went down to 7. So it appears that the number of tuples corresponds to iterations of the decoder over some chunk size, up to some limit.
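One way to check what the tuple count actually tracks (a sketch, assuming the same generate() call as above) is to compare it against the length of the generated sequences:

# If each outer tuple were one decoding step, the count should track the
# generated length (outputs.sequences includes the decoder start token).
num_steps = len(outputs.decoder_hidden_states)
generated_length = outputs.sequences.shape[1]
print(num_steps, generated_length)  # typically num_steps == generated_length - 1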

So, if I wanted to compute mean pooling over the first token, it seems like I'd compute the mean over the last tensor of the first tuple. However, I don't understand exactly how the token length corresponds to the number of tuples.
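In code, that guess would look like this (a sketch, assuming the last tensor in each tuple really is the top layer):

# Last tensor of the first tuple: presumably the top-layer state for the
# first generated token, shape [batch_size, 1, hidden_size].
first_token_state = outputs.decoder_hidden_states[0][-1]
first_token_pooled = torch.mean(first_token_state, dim=1)  # trivial mean over a length-1 axis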

How do I determine what exactly is represented by each of these tuples and tensors? I have not been successful in finding this information by going through the T5 source code.


Solution

  • I think what's happening is that T5 returns the hidden states per step of decoding. Therefore, the number of tuples should correspond to the longest generated sequence. You are most likely interested in the last decoding step and could take the last tuple.

    That tuple in turn has num_layers + 1 entries (+1 for the final LayerNorm). The output of the last layer should be the last entry.
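    Under that interpretation, a minimal sketch of mean pooling over all generated tokens would concatenate the last entry of each per-step tuple along the sequence dimension (names are illustrative; this does not mask out padding after an early EOS):

    # Outer tuple = decoding step; last inner entry = final (LayerNorm-ed)
    # layer, shape [batch_size, 1, hidden_size] per step.
    per_step_last_layer = [step[-1] for step in outputs.decoder_hidden_states]
    decoder_states = torch.cat(per_step_last_layer, dim=1)  # [batch_size, num_steps, hidden_size]
    mean_pooled_output = decoder_states.mean(dim=1)         # [batch_size, hidden_size]

    If sequences in the batch finish at different steps, the later steps hold states for pad tokens, so you would want to build a mask from outputs.sequences before averaging.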