Tags: python, nlp, huggingface-transformers

How to finetune a zero-shot model for text classification


I need a model that is able to classify text for an unknown number of classes (i.e. the number might grow over time). The entailment approach for zero-shot text classification seems to be the solution to my problem, but the model I tried, facebook/bart-large-mnli, doesn't perform well on my annotated data. Is there a way to fine-tune it without losing the robustness of the model?

My dataset looks like this:

# http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html
World, "Afghan Army Dispatched to Calm Violence KABUL, Afghanistan - Government troops intervened in Afghanistan's latest outbreak of deadly fighting between warlords, flying from the capital to the far west on U.S. and NATO airplanes to retake an air base contested in the violence, officials said Sunday..."
Sports, "Johnson Helps D-Backs End Nine-Game Slide (AP) AP - Randy Johnson took a four-hitter into the ninth inning to help the Arizona Diamondbacks end a nine-game losing streak Sunday, beating Steve Trachsel and the New York Mets 2-0." 
Business, "Retailers Vie for Back-To-School Buyers (Reuters) Reuters - Apparel retailers are hoping their\back-to-school fashions will make the grade among\style-conscious teens and young adults this fall, but it could\be a tough sell, with students and parents keeping a tighter\hold on their wallets."

P.S.: This is an artificial question that was created because this topic came up in the comment section of this post which is related to this post.


Solution

    Concept explanation

    Before I answer your question, it is crucial to understand how the entailment approach for zero-shot text classification works. This approach requires a model that was trained for NLI (natural language inference), which means that it is able to determine whether a hypothesis is:

    • supported,
    • not supported,
    • undetermined

    by a given premise [1]. You can verify that for the model you mentioned with the following code:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
    # The classification head outputs three logits
    print(nli_model.classification_head.out_proj)
    # Each logit corresponds to one of the following labels
    print(nli_model.config.id2label)
    

    Output:

    Linear(in_features=1024, out_features=3, bias=True)
    {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
    

    The entailment approach, proposed by Yin et al., uses these NLI capabilities by treating the text as the premise and formulating a hypothesis for each possible class with the template:

    "the text is about {}"
    

    That means when you have a text and three potential classes, you will pass three sequences to the NLI model and compare the entailment logits to classify the text.
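    To make that comparison concrete, here is a minimal sketch in plain Python. It assumes you already have the three NLI logits for each (premise, hypothesis) pair; with the real model you would get them by tokenizing the pairs with the model's tokenizer (`tokenizer(premises, hypotheses, return_tensors="pt", padding=True)`) and reading `nli_model(**inputs).logits`. The helper names `build_pairs` and `zero_shot_scores` are hypothetical, the logits in the usage part are made up, and the softmax over the entailment logits of all candidate classes is one common way to do the comparison (roughly what the Hugging Face `zero-shot-classification` pipeline does in single-label mode).

```python
import math
from typing import Dict, List, Sequence, Tuple

# Index of the "entailment" logit, taken from nli_model.config.id2label
# for facebook/bart-large-mnli: {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ENTAILMENT_ID = 2


def build_pairs(text: str, classes: Sequence[str],
                template: str = "This example is {}.") -> List[Tuple[str, str]]:
    """Pair the text (premise) with one hypothesis per candidate class."""
    return [(text, template.format(c)) for c in classes]


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def zero_shot_scores(nli_logits: Sequence[Sequence[float]],
                     classes: Sequence[str]) -> Dict[str, float]:
    """Turn per-pair NLI logits into class probabilities.

    nli_logits[i] holds the [contradiction, neutral, entailment] logits of the
    pair built for classes[i]. Only the entailment logits are compared, via a
    softmax across the candidate classes.
    """
    entailment = [row[ENTAILMENT_ID] for row in nli_logits]
    return dict(zip(classes, softmax(entailment)))


if __name__ == "__main__":
    classes = ["World", "Sports", "Business"]
    text = "Randy Johnson took a four-hitter into the ninth inning."
    for premise, hypothesis in build_pairs(text, classes):
        print(hypothesis)  # one hypothesis per candidate class
    # Made-up logits standing in for nli_model(**inputs).logits
    fake_logits = [[2.0, 0.5, -1.0], [-1.5, 0.2, 3.0], [0.1, 0.3, 0.4]]
    scores = zero_shot_scores(fake_logits, classes)
    print(max(scores, key=scores.get))  # Sports
```

    Only the entailment logits are used in this single-label variant; adding a new class just means adding one more hypothesis, which is what makes the approach work for a growing set of classes.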

    Finetuning

    To fine-tune an NLI model on your annotated data, you therefore need to formulate your text classification task as an NLI task! That means you need to generate hypotheses for your texts (which serve as the premises), and the labels need to be either contradiction or entailment. The contradiction label is included so that the model does not only see hypotheses that are entailed by their respective premise (i.e. the model needs to learn contradiction in order to predict a low entailment score for wrong classes in the zero-shot text classification task).

    The following code shows you an example of how to prepare your dataset:

    import random
    from datasets import load_dataset
    from transformers import AutoTokenizer
    
    your_dataset = load_dataset("ag_news", split="test")
    id2labels = ["World", "Sports", "Business", "Sci/Tech"]
    # Replace the integer label with its class name
    your_dataset = your_dataset.map(lambda x: {"class": id2labels[x["label"]]}, remove_columns=["label"])
    
    print(your_dataset[0])
    
    # the relevant code
    t = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
    template = "This example is {}."
    
    def create_input_sequence(sample):
      # With batched=True and batch_size=1, each column arrives as a list holding one element
      text = sample["text"]
      label = sample["class"][0]
      # Pick a random wrong class to build the contradiction example
      contradiction_label = random.choice([x for x in id2labels if x!=label])
    
      # text*2 repeats the one-element list, so the premise is paired with both hypotheses
      encoded_sequence = t(text*2, [template.format(label), template.format(contradiction_label)])
      # 2 = entailment (correct class), 0 = contradiction (wrong class)
      encoded_sequence["labels"] = [2,0]
      encoded_sequence["input_sentence"] = t.batch_decode(encoded_sequence.input_ids)
    
      return encoded_sequence
    
    # Each input row becomes two rows: one entailment and one contradiction example
    train_dataset = your_dataset.map(create_input_sequence, batched=True, batch_size=1, remove_columns=["class", "text"])
    print(train_dataset[0])
    

    Output:

    {'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.", 
    'class': 'Business'}
    
    {'input_ids': [0, 597, 12541, 13, 255, 234, 4931, 71, 1431, 1890, 2485, 4561, 1138, 23, 6980, 1437, 1437, 188, 1250, 224, 51, 32, 128, 7779, 19051, 108, 71, 1431, 19, 35876, 4095, 933, 1853, 18059, 922, 4, 2, 2, 713, 1246, 16, 2090, 4, 2], 
    'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
    'labels': 2, 
    'input_sentence': "<s>Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.</s></s>This example is Business.</s>"}
    

    Robustness

    Fine-tuning will likely reduce the robustness of your model (i.e. its ability to provide decent results for classes that weren't part of your fine-tuning dataset). To mitigate that, you could try:

    • Stopping training before convergence and checking whether the performance is still sufficient for your needs.
    • WiSE-FT, proposed by Wortsman et al. Pseudocode is shown in Appendix A of their paper.
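    The core step of WiSE-FT is a linear interpolation in weight space between the zero-shot and the fine-tuned weights, θ = (1 − α)·θ_zero-shot + α·θ_fine-tuned. The following is a minimal sketch of that interpolation in the spirit of the paper, not a transcription of its pseudocode. It assumes PyTorch state dicts; `wise_ft` is a hypothetical helper name, and the toy `nn.Linear` layers stand in for facebook/bart-large-mnli and your fine-tuned copy. With the real models you would pass `zeroshot_model.state_dict()` and `finetuned_model.state_dict()` and load the result back with `load_state_dict`. This is possible here because fine-tuning on the NLI formulation keeps the architecture and parameter names unchanged.

```python
from typing import Dict

import torch


def wise_ft(zeroshot_state: Dict[str, torch.Tensor],
            finetuned_state: Dict[str, torch.Tensor],
            alpha: float) -> Dict[str, torch.Tensor]:
    """Linearly interpolate two state dicts that have identical keys.

    alpha = 0.0 returns the zero-shot weights, alpha = 1.0 the fine-tuned ones.
    """
    # Restricting alpha to [0, 1] is a choice made for this sketch
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if set(zeroshot_state) != set(finetuned_state):
        raise ValueError("both models must have the same parameter names")
    merged = {}
    for name, theta_0 in zeroshot_state.items():
        theta_1 = finetuned_state[name]
        if torch.is_floating_point(theta_0):
            merged[name] = (1.0 - alpha) * theta_0 + alpha * theta_1
        else:
            # Non-float buffers (e.g. integer ids) cannot be averaged; keep the fine-tuned value
            merged[name] = theta_1.clone()
    return merged


if __name__ == "__main__":
    torch.manual_seed(0)
    zeroshot = torch.nn.Linear(4, 3)   # stand-in for the zero-shot NLI model
    finetuned = torch.nn.Linear(4, 3)  # stand-in for the fine-tuned NLI model
    merged_model = torch.nn.Linear(4, 3)
    merged_model.load_state_dict(wise_ft(zeroshot.state_dict(), finetuned.state_dict(), alpha=0.5))
    print(merged_model.bias)
```

    A reasonable way to pick α is to evaluate a few values on both your annotated classes and some classes that were not part of fine-tuning, and keep the one that balances the two.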