Based on the documentation provided here, https://github.com/huggingface/transformers/blob/v4.21.3/src/transformers/modeling_outputs.py#L101, how can I read all the outputs (last_hidden_state, pooler_output and hidden_states)? In my sample code below, I get the outputs like this:
from transformers import BertModel, BertConfig
config = BertConfig.from_pretrained("xxx", output_hidden_states=True)
model = BertModel.from_pretrained("xxx", config=config)
outputs = model(inputs)
When I print one of the outputs (sample below), the value for last_hidden_state is
tensor([[...
I looked through the documentation to see whether this class has a function that returns just the last_hidden_state values, but I'm not sure what type it is. Is it a class, a tuple, or an array? How can I get the values as an array, such as
[0, 1, 2, 3, ...]
BaseModelOutputWithPoolingAndNoAttention(
    last_hidden_state=tensor([
        [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
        ...
    hidden_states= ...
The BaseModelOutputWithPoolingAndCrossAttentions you retrieve is a class that inherits from OrderedDict (through the transformers ModelOutput base class) and holds PyTorch tensors. You can access the keys of the OrderedDict like attributes of a class and, in case you do not want to work with tensors, you can convert them to Python lists or NumPy arrays. Please have a look at the example below:
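from transformers import BertTokenizer, BertModel
t = BertTokenizer.from_pretrained("bert-base-cased")
m = BertModel.from_pretrained("bert-base-cased")
# Tokenize and return PyTorch tensors (input_ids, token_type_ids, attention_mask)
i = t("This is a test", return_tensors="pt")
# Also return the hidden states of the embedding layer and every encoder layer
o = m(**i, output_hidden_states=True)
print(o.keys())                              # names of the fields that are not None
print(type(o))                               # the ModelOutput subclass itself
print(type(o.last_hidden_state))             # each field is a plain torch.Tensor
print(o.last_hidden_state.tolist())          # nested Python lists
# numpy() refuses tensors that require grad, so detach() from the autograd graph first
print(o.last_hidden_state.detach().numpy())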
Output:
odict_keys(['last_hidden_state', 'pooler_output', 'hidden_states'])
<class 'transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions'>
<class 'torch.Tensor'>
[[[0.36328405141830444, 0.018902940675616264, 0.1893523931503296, ..., 0.09052444249391556, 1.4617693424224854, 0.0774402841925621]]]
[[[ 0.36328405  0.01890294  0.1893524  ... -0.0259465   0.38701165  0.19099694]
  [ 0.30656984 -0.25377586  0.76075834 ...  0.2055152   0.29494798  0.4561815 ]
  [ 0.32563183  0.02308523  0.665546   ...  0.34597045 -0.0644953   0.5391255 ]
  [ 0.3346715  -0.02526359  0.12209094 ...  0.50101244  0.36993945  0.3237842 ]
  [ 0.18683438  0.03102166  0.25582778 ...  0.5166369  -0.1238729   0.4419385 ]
  [ 0.81130844  0.4746894  -0.03862225 ...  0.09052444  1.4617693   0.07744028]]]
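The example above has to download bert-base-cased. As a minimal sketch that runs offline, the snippet below builds a BaseModelOutputWithPoolingAndCrossAttentions by hand and shows the other ways to read it: by key, by integer position, and through to_tuple(), plus how hidden_states is laid out. The shapes (batch of 1, 6 tokens, hidden size 768, 12 encoder layers plus the embedding output, as for bert-base-cased) are assumptions chosen to mirror the example above, and the values are random rather than real model outputs. It assumes a transformers version whose ModelOutput supports integer indexing and to_tuple(), which is the case for the v4.21.3 source linked in the question.

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

# Assumed shapes mirroring bert-base-cased on "This is a test":
# 1 sequence, 6 tokens ([CLS] This is a test [SEP]), hidden size 768, 12 layers.
batch, seq_len, hidden, num_layers = 1, 6, 768, 12

# hidden_states: the embedding output plus one tensor per encoder layer (random stand-ins).
hidden_states = tuple(torch.randn(batch, seq_len, hidden) for _ in range(num_layers + 1))

o = BaseModelOutputWithPoolingAndCrossAttentions(
    last_hidden_state=hidden_states[-1],  # for BERT, the last entry of hidden_states
    pooler_output=torch.randn(batch, hidden),
    hidden_states=hidden_states,
)

# Fields left as None (attentions, cross_attentions, ...) are not stored as keys.
print(list(o.keys()))  # ['last_hidden_state', 'pooler_output', 'hidden_states']

# The same tensor is reachable by attribute, by key and by position.
assert o.last_hidden_state is o["last_hidden_state"]
assert o.last_hidden_state is o[0]

# to_tuple() gives a plain tuple of the non-None fields, in declaration order.
last, pooled, all_layers = o.to_tuple()
print(last.shape)       # torch.Size([1, 6, 768])
print(pooled.shape)     # torch.Size([1, 768])
print(len(all_layers))  # 13

# Vector of the first token ([CLS]) of the first sequence, as a Python list.
cls_vector = o.last_hidden_state[0, 0].tolist()
print(len(cls_vector))  # 768

# Whole tensor as a NumPy array (cpu() matters only if the tensor lives on a GPU).
arr = o.last_hidden_state.detach().cpu().numpy()
print(arr.shape, arr.dtype)  # (1, 6, 768) float32
```

With a real model, running the forward pass inside `with torch.no_grad():` returns tensors that do not require grad, so `.numpy()` then works without `.detach()`.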