Tags: huggingface-transformers, text-classification, bert-language-model

Extend BERT or any transformer model using manual features


I am working on a thesis about citation classification. I have just implemented a BERT model for classifying citations: I have 4 output classes, and given an input sentence my model returns the category of the citation. Now my supervisor has given me another task:

You have to find out whether it is possible to extend BERT or any transformer model using manual features. E.g. you are currently giving a sentence as the only input, followed by its class. What if you could give a sentence and some other features as input, as we do in other classifiers?

I need some guidance on this problem: how can I add an extra feature to my BERT model when the feature is categorical rather than numerical?


Solution

  • There are several ways to achieve that. I will explain just two of them in the following answer:

    1. Add the category as a token:
      The idea of this approach is rather simple: since transformer models like BERT are able to produce contextualized embeddings for a given sentence, why not incorporate the categorical features as text as well? For example, say you use the title of a cited paper as input and also want to incorporate the research area of the paper to provide more context:
    "Attention is all you need. [Computer Science] [Machine Translation]" -> BERT
    

    To do that, I would add the categories of your new feature as separate tokens to BERT (that is not required but reduces the sequence length) and fine-tune it for a few epochs (a minimal training sketch follows the output below):

    from transformers import BertTokenizer, BertForSequenceClassification
    
    my_categories = ["[Computer Science]", "[Machine Translation]"]
    sentence = "Attention is all you need. [Computer Science] [Machine Translation]"
    
    t = BertTokenizer.from_pretrained("bert-base-cased")
    m = BertForSequenceClassification.from_pretrained("bert-base-cased")
    
    # Tokenized without the separate tokens
    print(len(t(sentence)["input_ids"]))
    
    # Tokenized with the separate tokens added
    t.add_tokens(my_categories)
    print(len(t(sentence)["input_ids"]))
    
    # Extend the embedding layer of the model accordingly
    print(m.resize_token_embeddings(len(t.get_vocab())))
    
    # Training...
    

    Output:

    18
    12
    Embedding(28998, 768, padding_idx=0)
    
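    If it helps, here is a minimal sketch of the fine-tuning step for this approach (the sentences, labels, and extra category tokens below are made-up placeholders; in practice you would batch your own citation dataset with a DataLoader):

    import torch
    from transformers import BertTokenizer, BertForSequenceClassification
    
    # Hypothetical toy data: citation sentences with category tokens appended, plus labels for 4 classes
    texts = [
        "Attention is all you need. [Computer Science] [Machine Translation]",
        "BERT: Pre-training of Deep Bidirectional Transformers. [Computer Science] [Language Modelling]",
    ]
    labels = torch.tensor([0, 2])
    
    t = BertTokenizer.from_pretrained("bert-base-cased")
    t.add_tokens(["[Computer Science]", "[Machine Translation]", "[Language Modelling]"])
    
    m = BertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=4)
    m.resize_token_embeddings(len(t))
    
    optimizer = torch.optim.AdamW(m.parameters(), lr=2e-5)
    
    m.train()
    for epoch in range(3):  # "a few epochs"
        batch = t(texts, padding=True, truncation=True, return_tensors="pt")
        out = m(**batch, labels=labels)
        out.loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"epoch {epoch}: loss {out.loss.item():.4f}")
    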
    2. Separate embedding layer:
      A more traditional way is to hold an embedding for each category and concatenate it (or combine the features in any other way) with the contextualized output of BERT before you feed it to the classification layer. For this approach, you can simply copy the code from Hugging Face's BertForSequenceClassification class (or whatever class you are using) and make the required changes:
    import torch
    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
    from transformers import BertPreTrainedModel, BertModel
    from typing import Optional
    
    class MyBertForSequenceClassification(BertPreTrainedModel):
        def __init__(self, config):
            super().__init__(config)
            self.num_labels = config.num_labels
            self.config = config
    
            self.bert = BertModel(config)
            classifier_dropout = (
                config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
            )
            self.dropout = torch.nn.Dropout(classifier_dropout)
            
            # Modified +20
            self.classifier = torch.nn.Linear(config.hidden_size +20, config.num_labels)
    
            # Modified 50 different categories embedding dimension 20 
            self.my_categorical_feature = torch.nn.Embedding(50,20)
    
            # Initialize weights and apply final processing
            self.post_init()
    
        # Modified new parameter categorical_feature_ids
        def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            head_mask: Optional[torch.Tensor] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            categorical_feature_ids: Optional[torch.Tensor] = None,
        ):
    
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
    
            pooled_output = outputs[1]
    
            # Modified get embeddings
            my_categorical_embedding = self.my_categorical_feature(categorical_feature_ids)
            my_categorical_embedding = self.dropout(my_categorical_embedding)
    
            pooled_output = self.dropout(pooled_output)
            
            # Modified concatenate contextualized embeddings from BERT and your categorical embedding
            pooled_output = torch.cat((pooled_output, my_categorical_embedding), dim=-1)
    
            logits = self.classifier(pooled_output)
    
            loss = None
            if labels is not None:
                if self.config.problem_type is None:
                    if self.num_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"
    
                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_labels == 1:
                        loss = loss_fct(logits.squeeze(), labels.squeeze())
                    else:
                        loss = loss_fct(logits, labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss = loss_fct(logits, labels)
            if not return_dict:
                output = (logits,) + outputs[2:]
                return ((loss,) + output) if loss is not None else output
    
            return {
                "loss":loss,
                "logits":logits,
                "hidden_states":outputs.hidden_states,
                "attentions":outputs.attentions,
            }
    

    You can use this class just as the BertForSequenceClassification class; the only difference is that it expects categorical_feature_ids as an additional input:

    import torch
    from transformers import BertTokenizer
    
    t = BertTokenizer.from_pretrained("bert-base-cased")
    m = MyBertForSequenceClassification.from_pretrained("bert-base-cased")
    
    # Batch with two sentences (i.e. the citation text you have already used)
    i = t(["paper title 1", "paper title 2"], padding=True, return_tensors="pt")
    
    # We assume that the first sentence (i.e. paper title 1) belongs to category 23 and the second sentence to category 42
    # You probably want to use a dictionary in your own code (see the sketch after the output below)
    i["categorical_feature_ids"] = torch.tensor([23, 42])
    
    print(m(**i))
    

    Output:

    {'loss': None, 
    'logits': tensor([[ 0.6069, -0.1878], [ 0.6347, -0.2608]], grad_fn=<AddmmBackward0>), 
    'hidden_states': None, 
    'attentions': None}
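
    As noted in the comment above, you would typically build the categorical_feature_ids from a dictionary that maps your raw category names to ids. A small sketch of what that could look like (the category names and labels are made-up placeholders; the ids must stay below the size of the embedding table defined in __init__, here 50):

    import torch
    from transformers import BertTokenizer
    
    t = BertTokenizer.from_pretrained("bert-base-cased")
    # num_labels=4 matches the four citation classes from the question
    m = MyBertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=4)
    
    # Hypothetical mapping from category names to ids for the nn.Embedding(50, 20) layer
    category2id = {"Computer Science": 0, "Machine Translation": 1, "Biology": 2}
    
    batch_sentences = ["paper title 1", "paper title 2"]
    batch_categories = ["Machine Translation", "Biology"]
    
    i = t(batch_sentences, padding=True, return_tensors="pt")
    i["categorical_feature_ids"] = torch.tensor([category2id[c] for c in batch_categories])
    
    # Passing labels additionally returns a loss that you can backpropagate during training
    i["labels"] = torch.tensor([0, 3])
    print(m(**i)["loss"])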