I am following Hugging Face's tutorial on fine-tuning a model. Unfortunately, it only shows how to fine-tune BERT as a classifier on labeled data. My case is a bit different: I want to fine-tune GPT-2 to generate text in a specific writing style, so my input is just the text (in that style) without any labels. I have tried the code below, but it doesn't work well: the generated text is of very poor quality and contains many special characters.
from transformers import (
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=gen_tokenizer,
    mlm=False,  # Suggestion from ChatGPT
)

# Initialize the Trainer
trainer = Trainer(
    model=gen_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

trainer.train()
Is there anything I should consider or change in my code? I'd be grateful for any answer, because I couldn't find anything about this online.
If you want to train your model to generate new text in a style similar to that of your texts, then this is Causal Language Modeling.
There is a separate page dedicated to this topic on HuggingFace: https://huggingface.co/docs/transformers/en/tasks/language_modeling.
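The part that most often causes the kind of garbage output you describe is the data preparation: GPT-2 has no padding token, and for causal LM that guide concatenates all texts and splits them into fixed-length blocks instead of padding each example. A rough sketch of that preprocessing (the variable name raw_dataset, the "text" column name and the block size are my assumptions, not taken from your code):

from transformers import AutoTokenizer

block_size = 512  # assumed; pick what fits your GPU memory

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def tokenize_function(examples):
    # Plain tokenization, no padding or truncation -- we chunk afterwards
    return tokenizer(examples["text"])

def group_texts(examples):
    # Concatenate everything and cut it into contiguous block_size chunks,
    # so each training example is a full block of your corpus
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

tokenized = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
lm_dataset = tokenized.map(group_texts, batched=True)

With the data prepared this way, your DataCollatorForLanguageModeling(mlm=False) only has to stack equal-length blocks, and the model never sees padding or stray special tokens.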
Or, if you want a complete guide, there is a nice article on Medium on how to fine-tune GPT-2: https://medium.com/@prashanth.ramanathan/fine-tuning-a-pre-trained-gpt-2-model-and-performing-inference-a-hands-on-guide-57c097a3b810. The dataset used there is wikitext (no labels), and the code sample looks like this:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir='/mnt/disks/disk1/results',
    evaluation_strategy='epoch',
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='/mnt/disks/disk1/logs'
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)