from transformers import AutoModelForTokenClassification, BertTokenizer
import torch

model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
label_list = [
    "O",       # Outside of a named entity
    "B-MISC",  # Beginning of a miscellaneous entity right after another miscellaneous entity
    "I-MISC",  # Miscellaneous entity
    "B-PER",   # Beginning of a person's name right after another person's name
    "I-PER",   # Person's name
    "B-ORG",   # Beginning of an organisation right after another organisation
    "I-ORG",   # Organisation
    "B-LOC",   # Beginning of a location right after another location
    "I-LOC"    # Location
]
sequence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very" \
"close to the Manhattan Bridge."
# Bit of a hack to get the tokens with the special tokens
tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(sequence)))
inputs = tokenizer.encode(sequence, return_tensors="pt")
outputs = model(inputs)[0]
predictions = torch.argmax(outputs, dim=2)
print([(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())])
Output:
[('[CLS]', 'O'), ('Hu', 'I-ORG'), ('##gging', 'I-ORG'), ('Face', 'I-ORG'), ('Inc', 'I-ORG'),
 ('.', 'O'), ('is', 'O'), ('a', 'O'), ('company', 'O'), ('based', 'O'), ('in', 'O'),
 ('New', 'I-LOC'), ('York', 'I-LOC'), ('City', 'I-LOC'), ('.', 'O'), ('Its', 'O'),
 ('headquarters', 'O'), ('are', 'O'), ('in', 'O'), ('D', 'I-LOC'), ('##UM', 'I-LOC'),
 ('##BO', 'I-LOC'), (',', 'O'), ('therefore', 'O'), ('very', 'O'), ('##c', 'O'),
 ('##lose', 'O'), ('to', 'O'), ('the', 'O'), ('Manhattan', 'I-LOC'), ('Bridge', 'I-LOC'),
 ('.', 'O'), ('[SEP]', 'O')]
I took this example from the Hugging Face Transformers documentation in order to understand how the library works, but I've run into a problem that I haven't been able to solve for a long time. After getting the output shown above, I want to get the character indices of the recognized entities in the "sequence" variable. How can I do this? I didn't find any method for this in the documentation; am I missing something?
For example:
('Hu', 'I-ORG'), ('##gging', 'I-ORG'), ('Face', 'I-ORG'), ('Inc', 'I-ORG') ---> (start: 0, end: 16)
Additional question: should I get rid of the ## prefixes (for example: ('##gging', 'I-ORG')) in my results, or is that okay to keep?
Everything you are trying to achieve is already available through the TokenClassificationPipeline:
from transformers import pipeline
ner = pipeline('token-classification', model='dbmdz/bert-large-cased-finetuned-conll03-english', tokenizer='dbmdz/bert-large-cased-finetuned-conll03-english')
sentence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very" \
"close to the Manhattan Bridge."
ner(sentence)
Output:
[{'end': 2,
'entity': 'I-ORG',
'index': 1,
'score': 0.9995108,
'start': 0,
'word': 'Hu'},
{'end': 7,
'entity': 'I-ORG',
'index': 2,
'score': 0.98959744,
'start': 2,
'word': '##gging'},
{'end': 12,
'entity': 'I-ORG',
'index': 3,
'score': 0.9979704,
'start': 8,
'word': 'Face'},
{'end': 16,
'entity': 'I-ORG',
'index': 4,
'score': 0.9993759,
'start': 13,
'word': 'Inc'},
{'end': 43,
'entity': 'I-LOC',
'index': 11,
'score': 0.9993406,
'start': 40,
'word': 'New'},
{'end': 48,
'entity': 'I-LOC',
'index': 12,
'score': 0.99919283,
'start': 44,
'word': 'York'},
{'end': 53,
'entity': 'I-LOC',
'index': 13,
'score': 0.99934113,
'start': 49,
'word': 'City'},
{'end': 80,
'entity': 'I-LOC',
'index': 19,
'score': 0.9863364,
'start': 79,
'word': 'D'},
{'end': 82,
'entity': 'I-LOC',
'index': 20,
'score': 0.939624,
'start': 80,
'word': '##UM'},
{'end': 84,
'entity': 'I-LOC',
'index': 21,
'score': 0.9121385,
'start': 82,
'word': '##BO'},
{'end': 122,
'entity': 'I-LOC',
'index': 29,
'score': 0.983919,
'start': 113,
'word': 'Manhattan'},
{'end': 129,
'entity': 'I-LOC',
'index': 30,
'score': 0.99242425,
'start': 123,
'word': 'Bridge'}]
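Each dictionary already contains the character offsets you asked about: 'start' and 'end' index directly into the original string, so slicing the sentence with them gives you back the entity text. A minimal sketch, assuming the ner pipeline and sentence defined above:

for ent in ner(sentence):
    # 'start'/'end' are character offsets into `sentence`
    print(sentence[ent['start']:ent['end']], ent['entity'], (ent['start'], ent['end']))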
You can also group the tokens by defining an aggregation strategy:
ner(sentence, aggregation_strategy='simple')
Output:
[{'end': 16,
'entity_group': 'ORG',
'score': 0.9966136,
'start': 0,
'word': 'Hugging Face Inc'},
{'end': 53,
'entity_group': 'LOC',
'score': 0.9992916,
'start': 40,
'word': 'New York City'},
{'end': 84,
'entity_group': 'LOC',
'score': 0.946033,
'start': 79,
'word': 'DUMBO'},
{'end': 129,
'entity_group': 'LOC',
'score': 0.98817164,
'start': 113,
'word': 'Manhattan Bridge'}]
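This grouped output also answers your follow-up question: with an aggregation strategy the pipeline merges the '##' continuation pieces back into whole words ('Hugging Face Inc', 'DUMBO'), so you don't need to strip them yourself.

If you'd rather stay close to your original code instead of using the pipeline, a fast tokenizer can return the character offsets directly via return_offsets_mapping. This is only a sketch under the assumption that you load the model's own fast tokenizer; it is not part of the documentation example:

from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)   # fast tokenizer, supports offset mapping
model = AutoModelForTokenClassification.from_pretrained(model_name)

encoded = tokenizer(sentence, return_tensors="pt", return_offsets_mapping=True)
offsets = encoded.pop("offset_mapping")[0]               # (start, end) character span per token
with torch.no_grad():
    logits = model(**encoded).logits
predictions = logits.argmax(dim=-1)[0]

for (start, end), pred in zip(offsets.tolist(), predictions.tolist()):
    label = model.config.id2label[pred]
    if label != "O" and start != end:                    # skip 'O' labels and special tokens
        print(sentence[start:end], label, (start, end))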