I want to use pre-trained XLNet (xlnet-base-cased, whose model type is Text Generation) or Chinese BERT (bert-base-chinese, whose model type is Fill-Mask) for sequence-to-sequence language model (Seq2SeqLM) training.
I can use facebook/bart-large (whose model type is Feature Extraction) to construct the Seq2SeqLM, but not the two pre-trained models mentioned above. Here is my code:
Load Dataset
from datasets import load_dataset
yuezh = load_dataset("my-custom-dataset")
Sample data from the dataset my-custom-dataset
{"translation": {"yue": "又睇", "zh": "再看"}}
{"translation": {"yue": "初頭", "zh": "開始的時候"}}
Create Test Split
yuezh = yuezh["train"].train_test_split(test_size=0.2)
print(yuezh)
Output:
DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 4246
    })
    test: Dataset({
        features: ['translation'],
        num_rows: 1062
    })
})
Tokenizer
from transformers import AutoTokenizer
checkpoint = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
Pre-process Function
def preprocess_function(examples):
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs
Parameters
source_lang = "yue"
target_lang = "zh"
prefix = "Translate this: "
tokenized_yuezh = yuezh.map(preprocess_function, batched=True)
tokenized_yuezh = tokenized_yuezh.remove_columns(yuezh["train"].column_names)
print(tokenized_yuezh)
Output:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 4246
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1062
    })
})
Evaluate Performance
# Use SacreBLEU to evaluate the performance
import evaluate
metric = evaluate.load("sacrebleu")
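Note that sacrebleu expects each prediction as a plain string and each reference as a list of strings (one list per prediction), which is why the postprocess_text helper further down wraps the labels in lists. A quick sanity check with a made-up pair:
# sacrebleu wants a list of prediction strings and a list of reference lists (made-up example)
example = metric.compute(predictions=["再看"], references=[["再看"]])
print(example["score"])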
Data Collator
from transformers import DataCollatorForSeq2Seq
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
Supporting Functions
import numpy as np
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
Training
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
training_args = Seq2SeqTrainingArguments(
    output_dir="my-output-dir",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    remove_unused_columns=False,
    fp16=True,
    push_to_hub=False,  # Don't push to Hub yet
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_yuezh["train"],
    eval_dataset=tokenized_yuezh["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
It results in the following error:
ValueError: Unrecognized configuration class <class
'transformers.models.bert.configuration_bert.BertConfig'> for this kind of AutoModel:
AutoModelForSeq2SeqLM.
Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig,
BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, LEDConfig, LongT5Config,
M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, PegasusConfig, PegasusXConfig,
PLBartConfig, ProphetNetConfig, SwitchTransformersConfig, T5Config, XLMProphetNetConfig.
So AutoModelForSeq2SeqLM does not support XLNet or BERT. What should I do to make sequence-to-sequence training work with these checkpoints?
xlnet-base-cased and bert-base-chinese cannot be loaded directly with AutoModelForSeq2SeqLM because that class expects an architecture that can perform seq2seq tasks out of the box.
But you can still leverage these checkpoints for a seq2seq model, thanks to this paper and the EncoderDecoderModel class:
from transformers import EncoderDecoderModel, AutoModelForSeq2SeqLM
# This will use the weights of bert-base-chinese for the encoder and the decoder
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-chinese", "bert-base-chinese")
model.config.decoder_start_token_id = tokenizer.cls_token_id
# You can later load it as AutoModelForSeq2SeqLM
#model.save_pretrained("my_seq2seqbert")
#model1 = AutoModelForSeq2SeqLM.from_pretrained("my_seq2seqbert")
I haven't tested it, but it seems XLNet cannot be used as a decoder, according to this issue. In that case, you can try a decoder such as GPT-2 with cross-attention:
model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlnet-base-cased", "gpt2")
#model.save_pretrained("my_seq2seqxlnet-base-cased")
#model1 = AutoModelForSeq2SeqLM.from_pretrained("my_seq2seqxlnet-base-cased")