Tags: tensorflow, nlp, gpt-2

How to save checkpoints for the transformer GPT-2 to continue training?


I am retraining the GPT-2 language model and am following this blog:

https://towardsdatascience.com/train-gpt-2-in-your-own-language-fc6ad4d60171

Here, they train a network on GPT-2, and I am trying to recreate the same setup. However, my dataset is quite large (250 MB), so I want to continue training in intervals. In other words, I want to checkpoint the model training. How can I do this?


Solution

    from transformers import TrainingArguments, Trainer, AutoModelForCausalLM

    training_args = TrainingArguments(
        output_dir=model_checkpoint,  # checkpoints and the final model are written here
        # other hyper-params
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=dev_set,
        tokenizer=tokenizer
    )

    trainer.train()
    # Save the final model to output_dir (model_checkpoint)
    trainer.save_model()

    def prepare_model(tokenizer, model_name_path):
        # Works with either a Hub model name or a local checkpoint directory
        model = AutoModelForCausalLM.from_pretrained(model_name_path)
        model.resize_token_embeddings(len(tokenizer))
        return model

    # Assuming the tokenizer is already defined, you can simply pass the saved model directory path.
    model = prepare_model(tokenizer, model_checkpoint)
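
During training, the Trainer also writes intermediate checkpoints (folders named checkpoint-500, checkpoint-1000, ...) into output_dir, and trainer.train() accepts a resume_from_checkpoint argument, so you can stop and continue in intervals instead of starting over. Below is a minimal sketch, assuming model, tokenizer, train_set, dev_set and model_checkpoint are the same objects as above; the save_steps value is only an illustrative choice.

    from transformers import TrainingArguments, Trainer

    training_args = TrainingArguments(
        output_dir=model_checkpoint,
        save_strategy="steps",    # write a checkpoint every `save_steps` optimizer steps
        save_steps=500,
        save_total_limit=2,       # keep only the two most recent checkpoints on disk
        # other hyper-params
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_set,
        eval_dataset=dev_set,
        tokenizer=tokenizer
    )

    # Resume from the latest checkpoint-* folder in output_dir; this restores
    # the optimizer and scheduler state as well as the model weights.
    trainer.train(resume_from_checkpoint=True)

    # Or point at a specific checkpoint directory:
    # trainer.train(resume_from_checkpoint=f"{model_checkpoint}/checkpoint-500")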