Tags: python, huggingface-transformers

How can I get indexes after getting NER results?


import torch
from transformers import AutoModelForTokenClassification, BertTokenizer

model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")



label_list = [
    "O",       # Outside of a named entity
    "B-MISC",  # Beginning of a miscellaneous entity right after another miscellaneous entity
    "I-MISC",  # Miscellaneous entity
    "B-PER",   # Beginning of a person's name right after another person's name
    "I-PER",   # Person's name
    "B-ORG",   # Beginning of an organisation right after another organisation
    "I-ORG",   # Organisation
    "B-LOC",   # Beginning of a location right after another location
    "I-LOC"    # Location
]

sequence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very" \
           "close to the Manhattan Bridge."

# Bit of a hack to get the tokens with the special tokens
tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(sequence)))
inputs = tokenizer.encode(sequence, return_tensors="pt")

outputs = model(inputs)[0]
predictions = torch.argmax(outputs, dim=2)

print([(token, label_list[prediction]) for token, prediction in zip(tokens, predictions[0].tolist())])

output:

    [('[CLS]', 'O'), ('Hu', 'I-ORG'), ('##gging', 'I-ORG'), ('Face', 'I-ORG'), ('Inc', 'I-ORG'),
     ('.', 'O'), ('is', 'O'), ('a', 'O'), ('company', 'O'), ('based', 'O'), ('in', 'O'),
     ('New', 'I-LOC'), ('York', 'I-LOC'), ('City', 'I-LOC'), ('.', 'O'), ('Its', 'O'),
     ('headquarters', 'O'), ('are', 'O'), ('in', 'O'), ('D', 'I-LOC'), ('##UM', 'I-LOC'),
     ('##BO', 'I-LOC'), (',', 'O'), ('therefore', 'O'), ('very', 'O'), ('##c', 'O'),
     ('##lose', 'O'), ('to', 'O'), ('the', 'O'), ('Manhattan', 'I-LOC'), ('Bridge', 'I-LOC'),
     ('.', 'O'), ('[SEP]', 'O')]

I took this example from the Hugging Face Transformers documentation in order to understand how the library works, but I have run into a problem I haven't been able to solve. After getting the output shown in the `print` call, I want to get the character indices of the recognized entities within the `sequence` variable. How can I do this? I didn't find any method for it in the documentation; am I missing something?

For example:

('Hu', 'I-ORG'), ('##gging', 'I-ORG'), ('Face', 'I-ORG'), ('Inc', 'I-ORG') ---> (start: 0, end: 16)

Additional question: should I get rid of the `##` prefixes (for example: `('##gging', 'I-ORG')`) in my results, or is this okay?


Solution

  • Everything you are trying to achieve is already available via the TokenClassificationPipeline:

    from transformers import pipeline
    
    ner = pipeline('token-classification',
                   model='dbmdz/bert-large-cased-finetuned-conll03-english',
                   tokenizer='dbmdz/bert-large-cased-finetuned-conll03-english')
    
    sentence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very" \
               "close to the Manhattan Bridge."
    
    ner(sentence)
    

    Output:

    [{'end': 2,
      'entity': 'I-ORG',
      'index': 1,
      'score': 0.9995108,
      'start': 0,
      'word': 'Hu'},
     {'end': 7,
      'entity': 'I-ORG',
      'index': 2,
      'score': 0.98959744,
      'start': 2,
      'word': '##gging'},
     {'end': 12,
      'entity': 'I-ORG',
      'index': 3,
      'score': 0.9979704,
      'start': 8,
      'word': 'Face'},
     {'end': 16,
      'entity': 'I-ORG',
      'index': 4,
      'score': 0.9993759,
      'start': 13,
      'word': 'Inc'},
     {'end': 43,
      'entity': 'I-LOC',
      'index': 11,
      'score': 0.9993406,
      'start': 40,
      'word': 'New'},
     {'end': 48,
      'entity': 'I-LOC',
      'index': 12,
      'score': 0.99919283,
      'start': 44,
      'word': 'York'},
     {'end': 53,
      'entity': 'I-LOC',
      'index': 13,
      'score': 0.99934113,
      'start': 49,
      'word': 'City'},
     {'end': 80,
      'entity': 'I-LOC',
      'index': 19,
      'score': 0.9863364,
      'start': 79,
      'word': 'D'},
     {'end': 82,
      'entity': 'I-LOC',
      'index': 20,
      'score': 0.939624,
      'start': 80,
      'word': '##UM'},
     {'end': 84,
      'entity': 'I-LOC',
      'index': 21,
      'score': 0.9121385,
      'start': 82,
      'word': '##BO'},
     {'end': 122,
      'entity': 'I-LOC',
      'index': 29,
      'score': 0.983919,
      'start': 113,
      'word': 'Manhattan'},
     {'end': 129,
      'entity': 'I-LOC',
      'index': 30,
      'score': 0.99242425,
      'start': 123,
      'word': 'Bridge'}]
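If you want to merge the per-token results into entity-level character spans yourself (e.g. to get `(start: 0, end: 16)` as in the question), the `start`/`end` offsets in the pipeline output are all you need. Below is a small helper of my own (not a transformers API) that folds consecutive tokens with the same label into spans; the sample `results` list is a shortened copy of the pipeline output above:

```python
def group_entities(tokens, text):
    """Merge consecutive tokens with the same entity label into
    (label, start, end, surface) spans. `tokens` is the list of dicts
    returned by the token-classification pipeline."""
    spans = []
    for tok in tokens:
        label = tok['entity'].split('-')[-1]  # 'I-ORG' -> 'ORG'
        # Extend the previous span if the labels match and the tokens are
        # adjacent (a '##' subword starts exactly where the previous token
        # ended; a following full word starts one position later).
        if spans and spans[-1][0] == label and tok['start'] <= spans[-1][2] + 1:
            spans[-1][2] = tok['end']
        else:
            spans.append([label, tok['start'], tok['end']])
    # Slicing the original text means '##' markers never appear in the result.
    return [(label, start, end, text[start:end]) for label, start, end in spans]

sentence = "Hugging Face Inc. is a company based in New York City."
results = [
    {'entity': 'I-ORG', 'start': 0,  'end': 2,  'word': 'Hu'},
    {'entity': 'I-ORG', 'start': 2,  'end': 7,  'word': '##gging'},
    {'entity': 'I-ORG', 'start': 8,  'end': 12, 'word': 'Face'},
    {'entity': 'I-ORG', 'start': 13, 'end': 16, 'word': 'Inc'},
    {'entity': 'I-LOC', 'start': 40, 'end': 43, 'word': 'New'},
    {'entity': 'I-LOC', 'start': 44, 'end': 48, 'word': 'York'},
    {'entity': 'I-LOC', 'start': 49, 'end': 53, 'word': 'City'},
]
print(group_entities(results, sentence))
# [('ORG', 0, 16, 'Hugging Face Inc'), ('LOC', 40, 53, 'New York City')]
```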
    

    You can also group the tokens into whole entities by defining an aggregation strategy:

    ner(sentence, aggregation_strategy='simple')
    

    Output:

    [{'end': 16,
      'entity_group': 'ORG',
      'score': 0.9966136,
      'start': 0,
      'word': 'Hugging Face Inc'},
     {'end': 53,
      'entity_group': 'LOC',
      'score': 0.9992916,
      'start': 40,
      'word': 'New York City'},
     {'end': 84,
      'entity_group': 'LOC',
      'score': 0.946033,
      'start': 79,
      'word': 'DUMBO'},
     {'end': 129,
      'entity_group': 'LOC',
      'score': 0.98817164,
      'start': 113,
      'word': 'Manhattan Bridge'}]
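This also answers the `##` question: the `start`/`end` values index straight into the original string, so slicing the sentence reproduces each entity's surface form with no `##` subword markers. A quick check, using the spans from the aggregated output above:

```python
sentence = "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, therefore very" \
           "close to the Manhattan Bridge."

# (start, end, label) spans taken from the aggregated pipeline output above.
entities = [(0, 16, 'ORG'), (40, 53, 'LOC'), (79, 84, 'LOC'), (113, 129, 'LOC')]

for start, end, label in entities:
    # Slicing the original string recovers the entity text; the '##'
    # continuation markers are purely a tokenizer artifact.
    print(f"{label}: {sentence[start:end]!r} ({start}, {end})")
# ORG: 'Hugging Face Inc' (0, 16)
# LOC: 'New York City' (40, 53)
# LOC: 'DUMBO' (79, 84)
# LOC: 'Manhattan Bridge' (113, 129)
```

In short: keep the character offsets and slice the original text, and the `##` prefixes disappear on their own; or simply use `aggregation_strategy='simple'` and read the `word` field directly.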