Tags: python, pytorch, huggingface-transformers, huggingface

XLNet or BERT Chinese for HuggingFace AutoModelForSeq2SeqLM Training


I want to use the pre-trained XLNet (xlnet-base-cased, whose model type is Text Generation) or Chinese BERT (bert-base-chinese, whose model type is Fill Mask) for Sequence-to-Sequence Language Model (Seq2SeqLM) training.

I can use facebook/bart-large (whose model type is Feature Extraction) to construct the Seq2SeqLM, but not the two pre-trained models mentioned above. Here is my code:

Load Dataset

from datasets import load_dataset
yuezh = load_dataset("my-custom-dataset")

Sample data from the dataset my-custom-dataset

{"translation": {"yue": "又睇", "zh": "再看"}}
{"translation": {"yue": "初頭", "zh": "開始的時候"}}

Create Test Split

yuezh = yuezh["train"].train_test_split(test_size=0.2)
print(yuezh)

Output:

DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 4246
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1062
    })
})

Tokenizer

from transformers import AutoTokenizer
checkpoint = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Pre-process Function

def preprocess_function(examples):
    # prefix, source_lang and target_lang are defined in the Parameters block below
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]
    # text_target tokenizes the targets into labels in the same call
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs

Parameters

source_lang = "yue"
target_lang = "zh"
prefix = "Translate this: "
tokenized_yuezh = yuezh.map(preprocess_function, batched=True)
tokenized_yuezh = tokenized_yuezh.remove_columns(yuezh["train"].column_names)
print(tokenized_yuezh)

Output:

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 4246
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1062
    })
})
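
To sanity-check the preprocessing, one example can be decoded back to text (just a quick inspection, not part of the training pipeline):

sample = tokenized_yuezh["train"][0]
print(tokenizer.decode(sample["input_ids"]))  # prefixed Cantonese source (plus special tokens)
print(tokenizer.decode(sample["labels"]))     # Mandarin target (plus special tokens)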

Evaluate Performance

# Use SacreBLEU to evaluate the performance
import evaluate
metric = evaluate.load("sacrebleu")
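
For reference, sacrebleu returns its corpus score under the "score" key (on a 0-100 scale), which is what compute_metrics below reads. A toy check:

# One list of references per prediction
print(metric.compute(predictions=["the cat sat on the mat"],
                     references=[["the cat sat on the mat"]])["score"])  # 100.0 for an exact match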

Data Collator

from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
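
For reference, the collator dynamically pads input_ids and attention_mask with the tokenizer's pad token and pads labels with -100 so the padded positions are ignored by the loss. A quick way to see what it produces:

# Inspect a collated batch of two tokenized examples
features = [tokenized_yuezh["train"][i] for i in range(2)]
batch = data_collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)  # padded to the longest sequence in this batch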

Supporting Functions

import numpy as np

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

Training

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
training_args = Seq2SeqTrainingArguments(
    output_dir="my-output-dir",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    remove_unused_columns=False,
    fp16=True,
    push_to_hub=False, # Don't push to Hub yet
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_yuezh["train"],
    eval_dataset=tokenized_yuezh["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

It results in the following error:

ValueError: Unrecognized configuration class <class 'transformers.models.bert.configuration_bert.BertConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM.
Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, SwitchTransformersConfig, T5Config, XLMProphetNetConfig.

The AutoModelForSeq2SeqLM does not support XLNet or BERT. What should I do to make the Sequence-to-Sequence training work?


Solution

  • xlnet-base-cased and bert-base-chinese cannot be loaded directly with AutoModelForSeq2SeqLM because it expects a model that can perform seq2seq tasks.

    But you can leverage these checkpoints for a seq2seq model thanks to this paper and the EncoderDecoderModel class:

    from transformers import EncoderDecoderModel, AutoModelForSeq2SeqLM
    # This will use the weights of bert-base-chinese for the encoder and the decoder
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-chinese", "bert-base-chinese")
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    
    # You can later load it as AutoModelForSeq2SeqLM 
    #model.save_pretrained("my_seq2seqbert")
    #model1 = AutoModelForSeq2SeqLM.from_pretrained("my_seq2seqbert")
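
    I haven't run a full training with this, but when plugging the encoder-decoder model into the Seq2SeqTrainer setup above, the generation-related token ids usually need to be set as well (a short sketch following the usual BERT2BERT warm-starting recipe; untested here):

    # Assumption: the same bert-base-chinese tokenizer as in the question
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.sep_token_id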
    

    I haven't tested it, but it seems XLNet cannot be used as a decoder according to this issue. In this case, you can try a decoder such as GPT-2 with cross-attention:

    model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlnet-base-cased", "gpt2")
    #model.save_pretrained("my_seq2seqxlnet-base-cased")
    #model1 = AutoModelForSeq2SeqLM.from_pretrained("my_seq2seqxlnet-base-cased")
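
    Note that xlnet-base-cased and gpt2 use different vocabularies, so the targets would have to be tokenized with a GPT-2 tokenizer rather than the BERT one, and GPT-2 has no padding token by default. A rough, untested sketch of the extra setup that implies:

    from transformers import AutoTokenizer
    # Assumption: targets are re-tokenized with the GPT-2 tokenizer for this pairing
    decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    decoder_tokenizer.pad_token = decoder_tokenizer.eos_token  # GPT-2 has no pad token
    model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id
    model.config.pad_token_id = decoder_tokenizer.pad_token_id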