pytorch, huggingface-transformers, transformer-model

How can/should we weight classes in HuggingFace token classification (entity recognition)?


I'm training a token classification (AKA named entity recognition) model with the HuggingFace Transformers library, using a customized data loader.

Like most NER datasets (I'd imagine?), there's a pretty significant class imbalance: a large majority of tokens are other - i.e. not an entity - and of course there's also some variation in frequency between the different entity classes themselves.

As we might expect, my "accuracy" metrics are getting distorted quite a lot by this: It's no great achievement to get 80% token classification accuracy if 90% of your tokens are other... A trivial model could have done better!

I can calculate some additional, more insightful evaluation metrics - but it got me wondering... Can/should we somehow incorporate class weights into the training loss? How would this be done with a typical *ForTokenClassification model, e.g. BertForTokenClassification?
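
For concreteness, by "weights" I mean something like inverse-frequency class weights computed from the training label distribution. A minimal sketch (the label set and counts below are made-up placeholders, just to show the idea):

    import torch
    from collections import Counter

    # Hypothetical label set; index 0 ("O") is the majority non-entity class.
    label_list = ["O", "B-PER", "I-PER"]

    # Placeholder for the flat list of gold label ids over the training set.
    train_label_ids = [0] * 900 + [1] * 60 + [2] * 40

    counts = Counter(train_label_ids)
    total = sum(counts.values())

    # Inverse-frequency weights: rarer classes get proportionally larger weights.
    class_weights = torch.tensor(
        [total / counts[i] for i in range(len(label_list))], dtype=torch.float
    )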


Solution

  • This is actually a really interesting question, since it seems there is no built-in way (yet) to customize the loss computation in the models themselves. Specifically for BertForTokenClassification, I found this code segment:

    # Inside BertForTokenClassification.forward, when labels are passed:
    loss_fct = CrossEntropyLoss()  # note: no class weights
    # ...
    # logits are flattened to (batch * seq_len, num_labels) before the loss
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
    

    To actually change the loss computation and add other parameters, e.g. the class weights you mention, you can go about it in one of two ways:

    • You can modify a local copy of transformers and install the library from there. This keeps the change itself small, but it's potentially quite a hassle to keep editing the library across different experiments; or
    • You let the model return its logits (which it does by default) and calculate your own loss outside the model's forward pass, as in the sketch below. In that case, watch out for the loss the model computes internally when you pass labels - the simplest way to avoid accidentally backpropagating it is to not pass labels at all, so no internal loss is computed.
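
    A minimal sketch of the second option, assuming three labels (index 0 being the majority other class), dummy inputs, and made-up weight values - and assuming a reasonably recent transformers version, where the model output exposes .logits. The key point is that no labels are passed into the model, so it computes no internal loss, and the weighted cross-entropy is applied to the returned logits with the same flattening the library uses internally:

        import torch
        from torch.nn import CrossEntropyLoss
        from transformers import BertForTokenClassification

        # Sketch setup: 3 labels, index 0 as the majority "other" class.
        model = BertForTokenClassification.from_pretrained(
            "bert-base-cased", num_labels=3
        )

        # Dummy batch, purely for illustration.
        input_ids = torch.randint(1000, 2000, (2, 16))
        attention_mask = torch.ones_like(input_ids)
        labels = torch.randint(0, 3, (2, 16))

        # Forward pass WITHOUT labels: the model returns logits and no loss.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # shape: (batch, seq_len, num_labels)

        # Made-up weights down-weighting the majority class; in practice,
        # derive these from your label frequencies (e.g. as in the question).
        class_weights = torch.tensor([0.1, 1.0, 1.0])
        loss_fct = CrossEntropyLoss(weight=class_weights, ignore_index=-100)

        # Same flattening the library does internally, now with weights.
        loss = loss_fct(logits.view(-1, model.num_labels), labels.view(-1))
        loss.backward()

    The ignore_index=-100 mirrors the library's convention for masking out special tokens and padding positions from the loss, so the same label tensors you already feed the model should work unchanged.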