huggingface-transformers, bert-language-model

How to get hidden layer/state outputs from a Bert model?


Based on the documentation provided here, https://github.com/huggingface/transformers/blob/v4.21.3/src/transformers/modeling_outputs.py#L101, how can I read all the outputs: last_hidden_state, pooler_output and hidden_states? In my sample code below, I get the outputs:

from transformers import BertModel, BertConfig

config = BertConfig.from_pretrained("xxx", output_hidden_states=True)
model = BertModel.from_pretrained("xxx", config=config)

outputs = model(inputs)

Below is a sample of what I see when I print one of the outputs. I looked through the documentation to see whether some method of this class would give me just the last_hidden_state values, but I'm not sure what type this is.

The value of last_hidden_state is:

tensor([[...

Is it a class, a tuple, or an array? How can I get the values, or an array of values, such as

[0, 1, 2, 3, ...]

BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=tensor([
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 
         11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 
         15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
        ...
        hidden_states= ...
        

Solution

  • The BaseModelOutputWithPoolingAndCrossAttentions you retrieve is a class that inherits from OrderedDict and holds PyTorch tensors (hidden_states is a tuple of tensors rather than a single tensor). You can access the keys of the OrderedDict like attributes of a class and, if you do not want to work with tensors, you can convert them to Python lists or NumPy arrays. Please have a look at the example below:

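    from transformers import BertTokenizer, BertModel

    t = BertTokenizer.from_pretrained("bert-base-cased")
    m = BertModel.from_pretrained("bert-base-cased")

    # Tokenize one sentence and return PyTorch tensors
    i = t("This is a test", return_tensors="pt")
    # Request the per-layer hidden states in addition to the default outputs
    o = m(**i, output_hidden_states=True)

    print(o.keys())                              # names of the available output fields
    print(type(o))                               # the output container class
    print(type(o.last_hidden_state))             # this field is a torch.Tensor
    print(o.last_hidden_state.tolist())          # values as nested Python lists
    print(o.last_hidden_state.detach().numpy())  # values as a NumPy array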
    

    Output:

    odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])
    <class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>
    <class 'torch.Tensor'>
    [[[0.36328405141830444, 0.018902940675616264, 0.1893523931503296, ..., 0.09052444249391556, 1.4617693424224854, 0.0774402841925621]]]
    [[[ 0.36328405  0.01890294  0.1893524  ... -0.0259465   0.38701165
        0.19099694]
      [ 0.30656984 -0.25377586  0.76075834 ...  0.2055152   0.29494798
        0.4561815 ]
      [ 0.32563183  0.02308523  0.665546   ...  0.34597045 -0.0644953
        0.5391255 ]
      [ 0.3346715  -0.02526359  0.12209094 ...  0.50101244  0.36993945
        0.3237842 ]
      [ 0.18683438  0.03102166  0.25582778 ...  0.5166369  -0.1238729
        0.4419385 ]
      [ 0.81130844  0.4746894  -0.03862225 ...  0.09052444  1.4617693
        0.07744028]]]
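
The question also asks about pooler_output and hidden_states, which the example above does not print. The following is a minimal sketch of reading all three fields. To run without downloading weights it builds a tiny, randomly initialised BertModel; the config sizes and the token ids are made-up assumptions for illustration only, so the values are meaningless and only the shapes and access patterns matter.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random BERT (assumed sizes, not a published checkpoint) so no download is needed.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)
model.eval()

# Made-up token ids: a batch of 1 sequence with 5 tokens.
input_ids = torch.tensor([[1, 5, 6, 7, 2]])

with torch.no_grad():
    out = model(input_ids=input_ids, output_hidden_states=True)

last_hidden = out.last_hidden_state  # tensor of shape (batch, seq_len, hidden_size)
pooled = out.pooler_output           # tensor of shape (batch, hidden_size)
all_layers = out.hidden_states       # tuple: embedding output + one tensor per layer

print(last_hidden.shape)
print(pooled.shape)
print(len(all_layers))

# The same field can be read as an attribute, by key, or by position.
same_by_key = torch.equal(out["last_hidden_state"], out[0])
# For BertModel, the final entry of hidden_states is the last_hidden_state itself.
same_as_last = torch.equal(all_layers[-1], last_hidden)

# Plain Python list of floats for the single sequence in the batch.
values = pooled.squeeze(0).tolist()
```

With a real checkpoint such as bert-base-cased, the same attribute names apply; only the sizes differ.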