Tags: python, pytorch, huggingface-transformers

Get all labels / entity groups available to a model


I have the following code to get the named entity values from a given text:

from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained("Davlan/distilbert-base-multilingual-cased-ner-hrl")
model = AutoModelForTokenClassification.from_pretrained("Davlan/distilbert-base-multilingual-cased-ner-hrl")
nlp = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")

example = "My name is Johnathan Smith and I work at Apple"
ner_results = nlp(example)
print(ner_results)

The following is the output:

[{'end': 26,
  'entity_group': 'PER',
  'score': 0.9994689,
  'start': 11,
  'word': 'Johnathan Smith'},
 {'end': 46,
  'entity_group': 'ORG',
  'score': 0.9983876,
  'start': 41,
  'word': 'Apple'}]

In the above example, the labels / entity groups are ORG and PER. How can I find all the labels / entity groups available to the model?

Kindly advise.


Solution

  • You can get this information from the id2label property of your model config:

    model.config.id2label
    

    Output:

    {0: 'O',
     1: 'B-DATE',
     2: 'I-DATE',
     3: 'B-PER',
     4: 'I-PER',
     5: 'B-ORG',
     6: 'I-ORG',
     7: 'B-LOC',
     8: 'I-LOC'}
    

    P.S.: Although the model has output labels (and therefore weights) for *-DATE, it does not seem to predict them, apparently because it was never trained on date annotations.
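
The pipeline reports entity groups (PER, ORG) rather than the raw B-/I- tags shown above, so it can help to collapse id2label into its distinct groups. Below is a minimal sketch, assuming the labels follow the B-/I- prefix scheme from the output above; the helper name entity_groups is hypothetical and not part of transformers, and the dictionary is copied from the output above so the snippet runs without downloading the model:

```python
def entity_groups(id2label):
    """Collapse BIO-style tags (B-PER, I-PER, ...) into bare entity groups.

    Hypothetical helper, not part of transformers. Assumes tags look like
    "B-XXX" / "I-XXX" and that "O" marks tokens outside any entity.
    """
    groups = set()
    for label in id2label.values():
        if label == "O":
            continue  # "O" is not an entity group
        prefix, sep, rest = label.partition("-")
        # Strip the B-/I- prefix if present; otherwise keep the label as-is.
        groups.add(rest if sep and prefix in {"B", "I"} else label)
    return sorted(groups)


# Copied from the model.config.id2label output shown above.
id2label = {
    0: "O",
    1: "B-DATE",
    2: "I-DATE",
    3: "B-PER",
    4: "I-PER",
    5: "B-ORG",
    6: "I-ORG",
    7: "B-LOC",
    8: "I-LOC",
}

print(entity_groups(id2label))  # ['DATE', 'LOC', 'ORG', 'PER']
```

With the model from the question already loaded, the same call would be entity_groups(model.config.id2label). Note that this lists every group the classification head defines, so DATE appears even though the model may never predict it.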