Tags: python, pytorch, parquet, huggingface-transformers

Stream a local Parquet file to the Hugging Face Trainer with an IterableDataset


I would like to stream a large .parquet file that I have locally in order to train a classification model. My script only seems to load the first mini-batch: the epoch count increases very quickly even though the file is very large and one epoch should take about ten hours. Here is the code I use:

import pyarrow.parquet as pq
import torch
import pandas as pd
import evaluate
from transformers import (AutoTokenizer, CamembertForSequenceClassification,
                          EarlyStoppingCallback, Trainer, TrainingArguments,
                          pipeline)
import numpy as np

class MyIterableDataset(torch.utils.data.IterableDataset):

    def __init__(self, parquet_file_path: str, tokenizer, label_encoder, batch_size: int = 8):
        self.parquet_file = pq.ParquetFile(parquet_file_path)
        self.generator = self.parquet_file.iter_batches(batch_size=batch_size)
        self.tokenizer = tokenizer
        self.label_encoder = label_encoder

    def __iter__(self):
        """ """
        data = next(self.generator)
        encodings = self.tokenizer(data['text'].tolist(), truncation=True, padding=True, max_length=512)
        items = []
        for idx in range(len(data)):
            item = {key: torch.tensor(val[idx]) for key, val in encodings.items()}
            item["labels"] = torch.tensor(self.label_encoder.transform([str(data['target'][idx])]))
            items.append(item)
        return iter(items)

path_train_parquet = '...'
path_dev_parquet = '...'

tokenizer = AutoTokenizer.from_pretrained("camembert-base")

model = CamembertForSequenceClassification.from_pretrained("camembert-base", num_labels=4)
metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels, average='macro')

train_dataset = MyIterableDataset(path_train_parquet, tokenizer, label_encoder, batch_size)
dev_dataset = MyIterableDataset(path_dev_parquet, tokenizer, label_encoder, batch_size)

training_args = TrainingArguments(
    output_dir=path_output_model,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir=path_logging_dir_model,
    logging_steps=10,
    load_best_model_at_end=True,
    evaluation_strategy='steps',
    eval_steps=200,
    save_total_limit=5,
    save_steps=200,
    report_to='none',
    max_steps=100000
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
trainer.train()
trainer.save_model(...)

Solution

  • The __iter__ method doesn't iterate over the entire dataset because it lacks a loop that repeatedly fetches and processes the next batch of data. Instead, it loads a single batch with next(self.generator), processes it, and returns an iterator over the items of that batch. Each pass over the dataset therefore yields only one mini-batch, so the Trainer considers an epoch finished after a handful of samples, which is why the epoch counter climbs so quickly.

    You could try something like this:

    def __iter__(self):
        while True:
            try:
                data = next(self.generator)  # fetch the next record batch
            except StopIteration:
                # end of the Parquet file: stop iterating
                break

            encodings = self.tokenizer(data['text'].tolist(), truncation=True, padding=True, max_length=512)
            items = []
            for idx in range(len(data)):  # encode each row of the batch
                item = {key: torch.tensor(val[idx]) for key, val in encodings.items()}
                item["labels"] = torch.tensor(self.label_encoder.transform([str(data['target'][idx])]))
                items.append(item)
            yield from items

    This version should let the __iter__ method keep fetching batches from self.generator, processing them, and yielding individual items from each batch until there are no more batches left in the .parquet file.
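
    One caveat with this approach: self.generator is created once in __init__, so it is exhausted after a single pass over the file. The Trainer iterates over eval_dataset at every evaluation step and over train_dataset once per epoch, so any later pass would yield nothing. Below is a minimal sketch of a variant that rebuilds the generator on every __iter__ call; the class name MyStreamingDataset is hypothetical, and it assumes the same text/target columns and a fitted scikit-learn LabelEncoder as in the question:

    import pyarrow.parquet as pq
    import torch

    class MyStreamingDataset(torch.utils.data.IterableDataset):

        def __init__(self, parquet_file_path: str, tokenizer, label_encoder, batch_size: int = 8):
            self.parquet_file = pq.ParquetFile(parquet_file_path)
            self.batch_size = batch_size  # stored so __iter__ can rebuild the generator
            self.tokenizer = tokenizer
            self.label_encoder = label_encoder

        def __iter__(self):
            # a fresh generator per pass, so every epoch / evaluation run
            # starts again from the beginning of the Parquet file
            for data in self.parquet_file.iter_batches(batch_size=self.batch_size):
                encodings = self.tokenizer(data['text'].tolist(), truncation=True,
                                           padding=True, max_length=512)
                for idx in range(len(data)):
                    item = {key: torch.tensor(val[idx]) for key, val in encodings.items()}
                    item["labels"] = torch.tensor(
                        self.label_encoder.transform([str(data['target'][idx])]))
                    yield item

    A quick way to check that iteration really continues past the first record batch is to pull a few items through a plain DataLoader before handing the dataset to the Trainer:

    from itertools import islice
    from torch.utils.data import DataLoader

    ds = MyStreamingDataset(path_train_parquet, tokenizer, label_encoder, batch_size=8)
    # record batches hold 8 rows each, so getting 20 items back proves we read past the first batch
    print(sum(1 for _ in islice(DataLoader(ds, batch_size=None), 20)))  # -> 20 (assuming the file has at least 20 rows)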

    https://www.datacamp.com/tutorial/python-iterators-generators-tutorial

    https://anandology.com/python-practice-book/iterators.html

    https://www.geeksforgeeks.org/difference-between-iterator-vs-generator/