Search code examples
Tags: python-3.x, pytorch, bert-language-model, named-entity-recognition, crf

How to add a simple custom pytorch-crf layer on top of a TokenClassification model using PyTorch and Trainer


I followed this link, but it's implemented in Keras.

Cannot add CRF layer on top of BERT in keras for NER

Model description

Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? It should make the model more robust.

from torchcrf import CRF

model_checkpoint = "dslim/bert-base-NER"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
# Standalone config object; note it is NOT passed to the model below.
config = BertConfig.from_pretrained(model_checkpoint, output_hidden_states=True)
# Pass output_hidden_states=True directly so the forward pass also returns the
# per-layer hidden states (outputs.hidden_states). Without it the model only
# returns logits, which is why indexing outputs[1][-1] in BERT_CRF failed.
bert_model = BertForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
    output_hidden_states=True,
)


class BERT_CRF(nn.Module):
    """BERT encoder with a linear emission layer and a CRF on top.

    Emission scores are computed from the concatenation of the last four
    hidden layers of ``bert_model``. The CRF turns them into a loss when
    labels are given, and into one decoded tag path per sequence otherwise.
    """

    def __init__(self, bert_model, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)

        # Four concatenated layers -> 4 * hidden_size features per token.
        # Falls back to 768 (the previously hard-coded width) if the wrapped
        # model exposes no config.
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.classifier = nn.Linear(4 * hidden_size, num_labels)

        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        """Run BERT, compute emission scores, and apply the CRF.

        Args:
            input_ids: token ids; presumably (batch, seq_len) -- confirm against the dataset.
            attention_mask: 1 for real tokens, 0 for padding; also used as the CRF mask.
            labels: optional gold tag ids; reshaped to ``attention_mask``'s 2-D shape.
            token_type_ids: optional segment ids, forwarded to BERT.

        Returns:
            ``[loss, prediction]`` when ``labels`` is given, otherwise ``prediction``,
            where ``prediction`` is the CRF's decoded tag-id list per sequence.
        """
        # Request hidden states explicitly: a token-classification model loaded
        # with default settings returns only logits, so outputs[1] would not be
        # the tuple of per-layer hidden states.
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        hidden_states = outputs.hidden_states

        # Concatenate last, 2nd-to-last, 3rd-to-last and 4th-to-last layers (in that order).
        sequence_output = torch.cat([hidden_states[-k] for k in range(1, 5)], dim=-1)
        sequence_output = self.dropout(sequence_output)

        emission = self.classifier(sequence_output)
        # pytorch-crf uses the mask as a torch.where condition, which needs a bool tensor.
        mask = attention_mask.bool()

        if labels is None:
            return self.crf.decode(emission, mask=mask)

        # Reshape only once labels are known to exist (labels=None is the inference path).
        labels = labels.reshape(attention_mask.size(0), attention_mask.size(1))
        # Ignore-index labels such as -100 would still be used as indices by the
        # CRF even at masked positions; replace them with 0 there. Unmasked
        # positions must carry valid tag ids.
        labels = labels.masked_fill(~mask, 0)

        # Raw emissions feed both the loss and decoding. A per-token log-softmax
        # adds the same constant to every tag at a position, which cancels in
        # the CRF log-likelihood and leaves the best path unchanged.
        loss = -self.crf(emission, labels, mask=mask, reduction='mean')
        prediction = self.crf.decode(emission, mask=mask)
        return [loss, prediction]

# Build the CRF wrapper that Trainer trains; `model` was otherwise never defined.
model = BERT_CRF(bert_model, num_labels=len(id2label))

args = TrainingArguments(
    "spanbert_crf_ner-pos2",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    # per_device_eval_batch_size=32
    # fp16 mixed precision needs a GPU; only enable it when CUDA is present
    # so the script still runs on CPU-only machines.
    fp16=torch.cuda.is_available(),
    # bf16=True #Ampere GPU
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

I get an error on the line `sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]),-1)`.

Since `outputs = self.bert(input_ids, attention_mask=attention_mask)` gives the logits for token classification, how can I get the hidden states so that I can concatenate the last 4 hidden states, i.e. so that I can do `outputs[1][-1]`?

Or is there an easier way to implement a BERT-CRF model?


Solution

  • I know it's 10 months later, but maybe it helps others.

    Here is what I used with `Trainer`; it also works with `hyperparameter_search`:

    class BERT_CRF_Config(PretrainedConfig):
        """Configuration for ``BERT_CRF``.

        Args:
            use_last_n_hidden_states: how many of the encoder's final hidden layers
                are concatenated before the emission layer (default 1).
            dropout: dropout probability applied to the concatenated layers (default 0.5).
            bert_name: name or path of the pretrained encoder to load.
            **kwarg: forwarded to ``PretrainedConfig`` (e.g. ``num_labels``).

        The three settings are explicit keyword parameters rather than values
        assigned after ``super().__init__``. That way, values passed by the caller
        or reloaded from a saved config are kept instead of being reset to the defaults.
        """
        model_type = "BERT_CRF"
    
        def __init__(self, use_last_n_hidden_states=1, dropout=0.5, bert_name=None, **kwarg):
            super().__init__(**kwarg)
            self.model_name = "BERT_CRF"
            self.use_last_n_hidden_states = use_last_n_hidden_states
            self.dropout = dropout
            self.bert_name = bert_name
    
    class BERT_CRF(PreTrainedModel):
        """BERT encoder + linear emission layer + CRF, usable with ``Trainer``.

        The emission layer reads the concatenation of the last
        ``config.use_last_n_hidden_states`` hidden layers of the encoder loaded
        from ``config.bert_name``. The CRF supplies the training loss and the
        decoded label paths.
        """
        config_class = BERT_CRF_Config
    
        def __init__(self, config):
            super().__init__(config)
    
            # Encoder config that also returns attentions and every hidden
            # state; forward() concatenates the last hidden states.
            bert_config = BertConfig.from_pretrained(config.bert_name)
            bert_config.output_attentions = True
            bert_config.output_hidden_states = True
    
            self.bert = AutoModel.from_pretrained(config.bert_name, config=bert_config)
    
            self.dropout = nn.Dropout(p=config.dropout)
    
            self.linear = nn.Linear(
                self.bert.config.hidden_size * config.use_last_n_hidden_states, config.num_labels)
            
            self.crf = CRF(config.num_labels, batch_first=True)

        def _as_tensor(self, value):
            """Return ``value`` as a tensor on the model device; ``None`` and tensors pass through."""
            if value is None or torch.is_tensor(value):
                return value
            return torch.tensor(value).to(self.device)

        @staticmethod
        def _pad_paths(paths, device):
            """Right-pad decoded label paths with 0 into one ``(batch, longest_path)`` long tensor."""
            return torch.nn.utils.rnn.pad_sequence(
                [torch.as_tensor(path, dtype=torch.long, device=device) for path in paths],
                batch_first=True,
                padding_value=0,
            )
    
        def forward(self,  input_ids = None, attention_mask = None, labels = None,
                    labels_mask=None,  token_type_ids=None, return_dict = None, **kwargs):
            """Compute the CRF loss (when labels are given) and the decoded label paths.

            List-like inputs are converted to tensors on the model device. ``None``
            inputs stay ``None``; they previously crashed in ``torch.tensor(None)``.

            Args:
                input_ids, attention_mask, token_type_ids: encoder inputs.
                labels: optional gold tag ids, same shape as the CRF mask.
                labels_mask: optional mask of positions the CRF scores and decodes.
                    Defaults to ``attention_mask``, or to all positions if that is
                    also ``None``.
                return_dict: when truthy and ``labels`` is given, return a
                    ``TokenClassifierOutput`` instead of a tuple.
                **kwargs: accepted and ignored.

            Returns:
                With labels: ``(loss, predictions)``, or ``TokenClassifierOutput``
                if ``return_dict``. Without labels: the padded predictions tensor.
            """
            input_ids = self._as_tensor(input_ids)
            token_type_ids = self._as_tensor(token_type_ids)
            attention_mask = self._as_tensor(attention_mask)
            labels = self._as_tensor(labels)
            labels_mask = self._as_tensor(labels_mask)
    
            bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, 
                                    attention_mask=attention_mask)

            # Concatenate the last n hidden layers along the feature axis.
            n_layers = self.config.use_last_n_hidden_states
            last_hidden_layers = torch.cat(bert_output['hidden_states'][-n_layers:], dim=2)
            last_hidden_layers = self.dropout(last_hidden_layers)
            logits = self.linear(last_hidden_layers)

            # One mask drives both the loss and decoding, so inference respects padding too.
            if labels_mask is not None:
                crf_mask = labels_mask.bool()
            elif attention_mask is not None:
                crf_mask = attention_mask.bool()
            else:
                crf_mask = None

            # Predictions are placed on the same device as the logits.
            outputs = self._pad_paths(self.crf.decode(logits, mask=crf_mask), logits.device)

            if labels is None:
                return outputs

            if crf_mask is not None:
                # pytorch-crf zeroes masked positions' scores but still indexes with
                # their labels, so ignore-index values such as -100 must be replaced.
                labels = labels.masked_fill(~crf_mask, 0)
            loss = -self.crf(logits, labels, mask=crf_mask)
            if not return_dict:
                return loss, outputs
            return TokenClassifierOutput(
                loss=loss,
                logits=outputs,
                hidden_states=bert_output.hidden_states,
                attentions=bert_output.attentions,
            )
    
        @property
        def device(self):
            return next(self.parameters()).device
    

    And for the hyperparameter search you can use something like this:

    def optuna_hp_space(trial):
        """Return the Optuna search space for the ``Trainer`` hyperparameters.

        Each entry samples from a fixed list of candidates. The lists contain no
        duplicated value, since a repeated candidate would be over-weighted.
        """
        return {
            # Evenly covers 1e-5 .. 6e-5 (3e-5 replaces a duplicated 2e-5).
            "learning_rate": trial.suggest_categorical("learning_rate", [1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5]),
            "warmup_ratio": trial.suggest_categorical("warmup_ratio", [0, 0.1, 0.2, 0.3]),
            "weight_decay": trial.suggest_categorical("weight_decay", [1e-6, 1e-5, 1e-4]),
            "max_grad_norm": trial.suggest_categorical("max_grad_norm", [8, 9, 10, 11]),
        }
    
    def model_init_crf(trial):
        """Build a fresh ``BERT_CRF`` model, sampling its structural hyperparameters.

        Args:
            trial: Optuna trial that samples ``dropout`` and how many final hidden
                layers to concatenate. It may be ``None`` when ``Trainer`` invokes
                ``model_init`` outside a search (e.g. at construction). In that case
                the config defaults are used instead of crashing on ``None``.

        Returns:
            A ``BERT_CRF`` model, moved to CUDA when available, otherwise to CPU.
        """
        config = BERT_CRF_Config.from_pretrained(BERT_MODEL, num_labels=NR_LABELS)
        config.bert_name = BERT_MODEL
        if trial is not None:
            config.dropout = trial.suggest_categorical("dropout", [0, 0.10, 0.30, 0.50])
            config.use_last_n_hidden_states = trial.suggest_categorical(
                "last_n_hidden_states", list(range(1, config.num_hidden_layers + 1)))

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = BERT_CRF(config).to(device)
        return model
    
    # Optuna search maximizing `my_objective` over `optuna_hp_space` for 50 trials.
    # NOTE(review): assumes `trainer` was built with `model_init=model_init_crf`,
    # so the dropout / layer-count choices are sampled too -- TODO confirm.
    best_trial = trainer.hyperparameter_search(
        backend="optuna",
        hp_space=optuna_hp_space,
        compute_objective=my_objective,
        n_trials=50,
        direction="maximize",
    )