
How is an object of type BaseModelOutput in the Hugging Face Transformers library subscriptable?


In my code, self.encoder is an object of type ViTEncoder, whose forward method returns an object of type BaseModelOutput.

When I call ViTEncoder's forward method, the object I get back is subscriptable. How is this possible?

# self.encoder is defined in the __init__ method of class ViTModel
self.encoder = ViTEncoder(config)

# code snippet from the forward method of class ViTModel
encoder_outputs = self.encoder(
    embedding_output,
    head_mask=head_mask,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)
sequence_output = encoder_outputs[0]  # how is this subscriptable?
print(f"The encoder output is {encoder_outputs}")  # prints a BaseModelOutput(...) object

BaseModelOutput is defined in transformers.modeling_outputs.


Solution

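  • BaseModelOutput inherits from ModelOutput, which implements __getitem__. According to its docstring, a ModelOutput can be indexed by an integer or slice (like a tuple) or by a string (like a dictionary), and attributes that are None are ignored. That is why encoder_outputs[0] returns encoder_outputs.last_hidden_state. In general, any class that defines __getitem__ is subscriptable; here is a minimal dataclass illustration: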

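    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class X:
        bla: Optional[List[str]] = None

        # Defining __getitem__ is what makes x[i] work:
        # Python evaluates x[1] as type(x).__getitem__(x, 1).
        def __getitem__(self, i: int) -> str:
            return self.bla[i]

    x = X(["tldr", "no", "yes"])
    print(x[1])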
    

    Output:

    no
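The toy class above simply forwards indexing to a list. The sketch below is a closer (but still hypothetical and simplified) imitation of the ModelOutput behaviour described above: a string key looks a field up by name, while an integer or slice indexes into to_tuple(), which contains only the fields that are not None. The names MiniModelOutput and MiniBaseModelOutput are invented for illustration, plain lists stand in for tensors, and this is not the transformers source code.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple


class MiniModelOutput:
    """Hypothetical, simplified stand-in for ModelOutput (not the library source)."""

    def _non_none_items(self) -> Dict[str, Any]:
        # Field values in declaration order, dropping any field that is None.
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        }

    def to_tuple(self) -> Tuple[Any, ...]:
        return tuple(self._non_none_items().values())

    def __getitem__(self, key):
        if isinstance(key, str):
            # String key: look the field up by name, like a dict (KeyError if it is None).
            return self._non_none_items()[key]
        # Integer or slice: index into the tuple of non-None fields.
        return self.to_tuple()[key]


@dataclass
class MiniBaseModelOutput(MiniModelOutput):
    # Field names mirror BaseModelOutput; plain lists stand in for tensors.
    last_hidden_state: Optional[List[float]] = None
    hidden_states: Optional[Tuple[List[float], ...]] = None
    attentions: Optional[Tuple[List[float], ...]] = None


out = MiniBaseModelOutput(last_hidden_state=[0.1, 0.2], attentions=([1.0],))
print(out[0])                    # [0.1, 0.2], same object as out.last_hidden_state
print(out["last_hidden_state"])  # [0.1, 0.2]
print(out[1])                    # ([1.0],): hidden_states is None, so attentions moves to index 1
```

Because None fields are skipped, the meaning of an integer index such as [1] depends on which optional outputs were requested (for example output_hidden_states). Accessing fields by name, as in encoder_outputs.last_hidden_state or encoder_outputs["last_hidden_state"], is therefore less fragile than positional indexing.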