Tags: python, nlp, huggingface-transformers, fine-tuning

Fine-tuning a pre-trained model using Hugging Face


I am completely new to NLP and I want to fine-tune a pre-trained model (rebel-large) on my own dataset. After reading some tutorials I realized there seems to be no difference between training and fine-tuning. I understand they are not the same thing, but in the code there is no difference.

How can I define how I want to fine-tune the model? For example, if I only want to change the weights in the last layer, how do I make the trainer do that?


Solution

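  • You're right: fine-tuning a model is just loading a pre-trained model and continuing to train it on your data. The difference is in the starting weights, not in the training code. The following snippet fine-tunes `bert-base-uncased` on the IMDB dataset; replace the model and dataset with yours. To update only some layers, set `requires_grad = False` on the parameters of the layers you want to keep fixed (print the model to see its layer names). Note that rebel-large is a sequence-to-sequence (BART-based) model, so you would load it with `AutoModelForSeq2SeqLM` rather than `AutoModelForSequenceClassification`.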

    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
    import numpy as np
    import evaluate
    
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
    metric = evaluate.load('accuracy')
    
    
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True)
    
    
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)
    
    
    dataset = load_dataset('imdb')
    dataset = dataset.map(tokenize_function, batched=True)
    
    # print(model)  # Use to discover your layers and choose which ones to put in `to_freeze`
    
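    # Freeze the whole BERT body (embeddings, encoder, pooler) so that only the
    # classification head (`model.classifier`) is trained. Freezing only
    # `model.bert.encoder` would still leave the embeddings and pooler trainable.
    to_freeze = [model.bert]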
    
    for layer in to_freeze:
        for param in layer.parameters():
            param.requires_grad = False
    
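    training_args = TrainingArguments(output_dir='test_trainer')
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        compute_metrics=compute_metrics,  # without this, `metric` above is never used
    )
    trainer.train()
    print(trainer.evaluate())  # eval loss and accuracy on the test split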
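If the goal is to update only the last layer, it can help to freeze by parameter name and check how many parameters remain trainable before starting a long run. Below is a minimal sketch in plain PyTorch, under the assumption that the last layer's parameter names share a prefix. The helpers `freeze_all_but` and `count_trainable` are hypothetical, and the small `nn.Sequential` stands in for a real pre-trained model. For a BERT-based `AutoModelForSequenceClassification`, the head's parameter names start with `classifier`, so the same call should leave only the head trainable; for other architectures, print `model.named_parameters()` to find the right prefix.

```python
import torch
from torch import nn


def freeze_all_but(model: nn.Module, trainable_prefixes: tuple) -> None:
    # Hypothetical helper: a parameter stays trainable only if its name starts
    # with one of the given prefixes; every other parameter is frozen.
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefixes)


def count_trainable(model: nn.Module) -> int:
    # Hypothetical helper: number of scalar weights that can still be updated.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


torch.manual_seed(0)

# Toy stand-in for a pre-trained network: an "encoder" followed by a final "classifier" layer.
model = nn.Sequential()
model.add_module("encoder", nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16), nn.ReLU()))
model.add_module("classifier", nn.Linear(16, 2))

freeze_all_but(model, ("classifier",))
print(count_trainable(model))  # 34 (16*2 weights + 2 biases)

# Give the optimizer only the trainable parameters.
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

encoder_before = model.encoder[0].weight.detach().clone()
head_before = model.classifier.weight.detach().clone()

# One training step on random data.
x = torch.randn(4, 8)
y = torch.tensor([0, 1, 0, 1])
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

print(torch.equal(encoder_before, model.encoder[0].weight))  # True: frozen layer unchanged
print(torch.equal(head_before, model.classifier.weight))     # False: head was updated
```

Frozen parameters receive no gradient, so they are not updated. Filtering them out of the optimizer is not strictly required (PyTorch optimizers skip parameters whose `.grad` is `None`), but it keeps the optimizer state small.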