Search code examples
pythonpython-3.xpytorchhuggingface-transformers

My `collate_fn` function receives empty data when I pass it as the `data_collator` parameter of `Trainer`


I am trying to fine-tune an existing Hugging Face model.

The code below is what I put together from several pieces of documentation:

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import torch

# Checkpoint identifier shared by the tokenizer and the model (Vietnamese PhoBERT)
model_name = "vinai/phobert-base"
# Tokenizer that belongs to the checkpoint
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
# Same checkpoint loaded with a question-answering head
model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path=model_name)

# Define the training data.
# "start" and "end" are 0-based character offsets of the answer inside
# "context"; "end" is inclusive, so context[start:end + 1] == text.
train_dataset = [
    {
        "question": "What is your name ?",
        "context": "My name is Peter",
        "answer": {
            "text": "Peter",
            # "Peter" occupies characters 11..15 of the context
            "start": 11,
            "end": 15,
        },
    },
]

# Define the validation data (same layout and offset convention as the
# training examples in this block).
val_dataset = [
    {
        "question": "What is your name ?",
        "context": "My name is Peter",
        "answer": {
            "text": "Peter",
            "start": 11,
            "end": 15,
        },
    },
]

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',          # where Trainer writes its output
    evaluation_strategy='epoch',     # evaluate once per epoch
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    # The datasets in this script are plain lists of dicts. With the default
    # remove_unused_columns=True, Trainer drops every key that is not an
    # argument of the model's forward() before it calls the data collator,
    # so the collator would receive examples with no fields left in them.
    remove_unused_columns=False,
)


# Define the data collator
def collate_fn(data):
    """Collate already-tokenized question-answering examples into one batch.

    Args:
        data: List of dict examples. For each required field, the tensors of
            all examples that contain that field are stacked along a new
            leading dimension, so they must share the same shape.

    Returns:
        A dict with the keys 'input_ids', 'attention_mask',
        'start_positions' and 'end_positions', each mapped to the stacked
        tensor for that field.

    Raises:
        RuntimeError: If no example in ``data`` contains one of the required
            fields (for instance because the examples were never tokenized
            or their keys were removed before collation).
    """
    required_fields = ('input_ids', 'attention_mask', 'start_positions', 'end_positions')

    def _stack_field(key):
        # Keep only the examples that actually carry this field.
        tensors = [item[key] for item in data if key in item]
        if not tensors:
            # torch.stack would fail here with the opaque message
            # "stack expects a non-empty TensorList"; say which field is
            # missing instead, keeping the same exception type.
            raise RuntimeError(
                f"collate_fn: no example in the batch has the key {key!r}; "
                "the examples must be tokenized and must still contain "
                "their tensor fields when they reach the collator"
            )
        return torch.stack(tensors)

    return {key: _stack_field(key) for key in required_fields}

# Assemble the Trainer from the model, arguments and datasets defined above;
# collate_fn is handed over as the function that turns examples into batches.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=collate_fn,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

# Start the fine-tuning run
trainer.train()

I keep receiving the following error:

    input_ids = torch.stack([item.get('input_ids', None) for item in data if 'input_ids' in item])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects a non-empty TensorList

I tried replacing the collator with a version that only prints its input:

def collate_fn(data):
    """Debugging stand-in for the collator: print the batch it receives."""
    print(data)
    return None

but it printed `[]`.


Solution

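  • There is only one example in `train_dataset`, so try a batch size of 1. Note also that the examples contain only the keys `question`, `context` and `answer`, so none of the keys that `collate_fn` looks for (`input_ids`, `attention_mask`, `start_positions`, `end_positions`) are present until the data has been tokenized.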