Tags: nlp, huggingface-transformers, huggingface

How to get back the predicted text from model output in HuggingFace?


I have the following toy example, which de-identifies a given text. I coded the following, but the output does not make sense. I am guessing that the way I am trying to get the predicted text back is incorrect.

import torch 
from transformers import AutoTokenizer, AutoModelForTokenClassification
tokenizer = AutoTokenizer.from_pretrained("obi/deid_bert_i2b2", do_lower_case=True)
model = AutoModelForTokenClassification.from_pretrained("obi/deid_bert_i2b2")

text = "Patient John Doe visited the hospital on 01/05/2023 with complaints of chest pain."

encoded_input = tokenizer(text, padding=True, return_tensors='pt')
outputs = model(**encoded_input)

# Get the predicted labels
predicted_labels = torch.argmax(outputs.logits, dim=2).squeeze()

# Convert the predicted labels back to text
predicted_tokens = tokenizer.batch_decode(predicted_labels, skip_special_tokens=True)

# Print the predicted tokens
print(predicted_tokens)

Output:

['[unused33]', '[unused33]', '[unused33]', '[unused7]', '[unused29]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused35]', '[unused12]', '[unused35]', '[unused12]', '[unused23]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]']

Solution

  • You are right: your decoding step isn't correct. The class labels are not part of the tokenizer vocabulary; they live in the model config (id2label):

    import torch
    from transformers import AutoTokenizer, AutoModelForTokenClassification
    
    model_id = "obi/deid_bert_i2b2"
    text = "Patient John Doe visited the hospital on 01/05/2023 with complaints of chest pain."
    
    t = AutoTokenizer.from_pretrained(model_id)
    m = AutoModelForTokenClassification.from_pretrained(model_id)
    
    encoded_text = t(text, return_tensors="pt")
    
    with torch.no_grad():
        logits = m(**encoded_text).logits
    
    # One predicted class id per token (argmax over the label dimension)
    token_class_ids = logits.argmax(-1)
    
    # Map each token back through the model config's id2label
    predictions = [
        (t.decode(t_id), m.config.id2label[c.item()])
        for t_id, c in zip(encoded_text["input_ids"][0], token_class_ids[0])
    ]
    print(*predictions, sep="\n")
    

    Output:

    ('[CLS]', 'O')
    ('Pat', 'O')
    ('##ient', 'O')
    ('John', 'B-PATIENT')
    ('Do', 'L-PATIENT')
    ('##e', 'O')
    ('visited', 'O')
    ('the', 'O')
    ('hospital', 'O')
    ('on', 'O')
    ('01', 'U-DATE')
    ('/', 'I-DATE')
    ('05', 'U-DATE')
    ('/', 'I-DATE')
    ('202', 'L-DATE')
    ('##3', 'O')
    ('with', 'O')
    ('complaints', 'O')
    ('of', 'O')
    ('chest', 'O')
    ('pain', 'O')
    ('.', 'O')
    ('[SEP]', 'O')
    
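    If the goal is actually to get a de-identified string back, you can mask the flagged character spans in the original input. Below is a minimal sketch (not part of the original answer): it reuses t, m and text from the snippet above, relies on return_offsets_mapping (which requires a fast tokenizer), and replaces each predicted entity span with its label's type, merging adjacent tokens of the same type:

    # Sketch: mask predicted entity spans in the original string.
    # Assumes t, m and text are defined as above; offsets need a fast tokenizer.
    enc = t(text, return_tensors="pt", return_offsets_mapping=True)
    offsets = enc.pop("offset_mapping")[0]
    
    with torch.no_grad():
        class_ids = m(**enc).logits.argmax(-1)[0]
    
    pieces, last = [], 0
    for (start, end), c in zip(offsets.tolist(), class_ids.tolist()):
        label = m.config.id2label[c]
        if start == end or label == "O":   # skip special tokens and non-entities
            continue
        tag = f"[{label.split('-')[-1]}]"  # e.g. 'B-PATIENT' -> '[PATIENT]'
        gap = text[last:start]
        if pieces and pieces[-1] == tag and not gap.strip():
            last = max(last, end)          # same entity continues, extend it
            continue
        pieces.append(gap)
        pieces.append(tag)
        last = max(last, end)
    print("".join(pieces) + text[last:])

    Note that with the token-level predictions shown above, the sub-word pieces '##e' and '##3' were predicted 'O', so the characters "e" and "3" would survive the masking. That is one reason to prefer the aggregating pipeline shown below.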

    In case you are only interested in inference, you might want to check out the token classification pipeline:

    from transformers import pipeline
    
    model_id = "obi/deid_bert_i2b2"
    
    text = "Patient John Doe visited the hospital on 01/05/2023 with complaints of chest pain."
    
    p = pipeline("token-classification", model_id)
    p(text)
    

    Output:

    [{'entity': 'B-PATIENT',
      'score': 0.9976101,
      'index': 3,
      'word': 'John',
      'start': 8,
      'end': 12},
     {'entity': 'L-PATIENT',
      'score': 0.98856366,
      'index': 4,
      'word': 'Do',
      'start': 13,
      'end': 15},
     {'entity': 'U-DATE',
      'score': 0.99967885,
      'index': 10,
      'word': '01',
      'start': 41,
      'end': 43},
     {'entity': 'I-DATE',
      'score': 0.83500373,
      'index': 11,
      'word': '/',
      'start': 43,
      'end': 44},
     {'entity': 'U-DATE',
      'score': 0.9905285,
      'index': 12,
      'word': '05',
      'start': 44,
      'end': 46},
     {'entity': 'I-DATE',
      'score': 0.9776883,
      'index': 13,
      'word': '/',
      'start': 46,
      'end': 47},
     {'entity': 'L-DATE',
      'score': 0.9986461,
      'index': 14,
      'word': '202',
      'start': 47,
      'end': 50}]
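
    The pipeline output above is per sub-word token. Passing aggregation_strategy (a standard token-classification pipeline argument) merges consecutive tokens into whole-entity spans, which is usually more convenient for de-identification. A minimal sketch, reusing model_id and text from above; the exact grouping depends on the model's BILOU tag scheme:

    p = pipeline("token-classification", model=model_id,
                 aggregation_strategy="simple")
    p(text)

    Each returned entry then carries an entity_group key (e.g. 'PATIENT', 'DATE') together with the merged word and its character span, instead of one entry per sub-word token.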