I have the following toy example, which de-identifies a given text. I coded it as shown below, but the output does not make sense. I suspect the way I am trying to get back the predicted text is incorrect.
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Load the de-identification tokenizer and token-classification model from the Hub.
tokenizer = AutoTokenizer.from_pretrained("obi/deid_bert_i2b2", do_lower_case=True)
model = AutoModelForTokenClassification.from_pretrained("obi/deid_bert_i2b2")
text = "Patient John Doe visited the hospital on 01/05/2023 with complaints of chest pain."
# Tokenize a single sentence into PyTorch tensors (a batch of one sequence).
encoded_input = tokenizer(text, padding=True, return_tensors='pt')
outputs = model(**encoded_input)
# Get the predicted labels: argmax over dim=2 picks one class id per token;
# squeeze() drops the size-1 batch dimension.
# (assumes logits are (batch, seq_len, num_labels) — TODO confirm)
predicted_labels = torch.argmax(outputs.logits, dim=2).squeeze()
# Convert the predicted labels back to text
# NOTE(review): BUG — predicted_labels holds CLASS ids, not vocabulary ids.
# batch_decode looks each one up in the tokenizer vocabulary, which is why the
# output is a list of unrelated '[unusedN]' tokens. Map the ids through
# model.config.id2label instead, and pair them with the input tokens.
predicted_tokens = tokenizer.batch_decode(predicted_labels, skip_special_tokens=True)
# Print the predicted tokens
print(predicted_tokens)
Output:
['[unused33]', '[unused33]', '[unused33]', '[unused7]', '[unused29]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused35]', '[unused12]', '[unused35]', '[unused12]', '[unused23]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]', '[unused33]']
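You are right: your decoding step is incorrect. The argmax over the logits gives class ids, not token ids. Passing them to `batch_decode` therefore maps them to unrelated vocabulary entries, which are the `[unusedN]` tokens you see. The class labels are not part of the tokenizer vocabulary; they live in the model config (`id2label`):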
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# Define the inputs here so the snippet runs on its own.
model_id = "obi/deid_bert_i2b2"
text = "Patient John Doe visited the hospital on 01/05/2023 with complaints of chest pain."

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id)
encoded_text = tokenizer(text, return_tensors="pt")

# Inference only: no gradients are needed.
with torch.no_grad():
    logits = model(**encoded_text).logits

# One predicted class id per token (argmax over the last, label dimension).
token_class_ids = logits.argmax(-1)

# Class ids index model.config.id2label, NOT the tokenizer vocabulary.
# Pair each input token (first and only sequence in the batch) with its label.
tokens = tokenizer.convert_ids_to_tokens(encoded_text["input_ids"][0])
predictions = [
    (token, model.config.id2label[class_id])
    for token, class_id in zip(tokens, token_class_ids[0].tolist())
]
print(*predictions, sep="\n")
Output:
('[CLS]', 'O')
('Pat', 'O')
('##ient', 'O')
('John', 'B-PATIENT')
('Do', 'L-PATIENT')
('##e', 'O')
('visited', 'O')
('the', 'O')
('hospital', 'O')
('on', 'O')
('01', 'U-DATE')
('/', 'I-DATE')
('05', 'U-DATE')
('/', 'I-DATE')
('202', 'L-DATE')
('##3', 'O')
('with', 'O')
('complaints', 'O')
('of', 'O')
('chest', 'O')
('pain', 'O')
('.', 'O')
('[SEP]', 'O')
If you are only interested in inference, consider the token-classification pipeline instead. It handles tokenization, the argmax, and the `id2label` mapping for you:
from pprint import pprint

from transformers import pipeline

model_id = "obi/deid_bert_i2b2"
text = "Patient John Doe visited the hospital on 01/05/2023 with complaints of chest pain."

# The pipeline tokenizes, runs the model, and maps class ids to label names.
# Tokens predicted as 'O' are omitted from its result; the remaining entries
# are per sub-word token with the model's raw tags (e.g. U-/I-/L-DATE).
p = pipeline("token-classification", model=model_id)

# Print explicitly: a bare `p(text)` only displays its value in a REPL/notebook,
# so when run as a script the result would otherwise be silently discarded.
pprint(p(text))
Output:
[{'entity': 'B-PATIENT',
'score': 0.9976101,
'index': 3,
'word': 'John',
'start': 8,
'end': 12},
{'entity': 'L-PATIENT',
'score': 0.98856366,
'index': 4,
'word': 'Do',
'start': 13,
'end': 15},
{'entity': 'U-DATE',
'score': 0.99967885,
'index': 10,
'word': '01',
'start': 41,
'end': 43},
{'entity': 'I-DATE',
'score': 0.83500373,
'index': 11,
'word': '/',
'start': 43,
'end': 44},
{'entity': 'U-DATE',
'score': 0.9905285,
'index': 12,
'word': '05',
'start': 44,
'end': 46},
{'entity': 'I-DATE',
'score': 0.9776883,
'index': 13,
'word': '/',
'start': 46,
'end': 47},
{'entity': 'L-DATE',
'score': 0.9986461,
'index': 14,
'word': '202',
'start': 47,
'end': 50}]