
pyarrow Table Filtering -- huggingface


I'm trying to filter a dataset down to the rows whose ids appear in a list, but this approach is too slow. The dataset is an Arrow dataset loaded from Hugging Face.

import numpy as np
from datasets import load_dataset, DatasetDict
from collections import Counter
import pyarrow as pa
import pandas as pd


responses = load_dataset('peixian/rtGender', 'responses', split='train')
# post_id_test_list is a list of post ids to keep
responses_test = responses.filter(lambda x: x['post_id'] in post_id_test_list)

Solution

  • The dataset you get from load_dataset isn't an Arrow dataset but a Hugging Face Dataset. It is backed by an Arrow table, though.

    Applying a lambda filter is going to be slow because the lambda is called in Python once per row. If you want a faster, vectorized operation, you could try filtering the underlying Arrow table directly:

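    import datasets
    import pyarrow as pa
    import pyarrow.compute as compute

    # Arrow table that backs the Hugging Face Dataset
    table = responses.data

    # Boolean mask: True where post_id is one of the ids in the list
    flags = compute.is_in(table['post_id'], value_set=pa.array(post_id_test_list, pa.int32()))
    filtered_table = table.filter(flags)

    # Wrap the filtered table in a Dataset again, reusing the original info and split
    filtered_responses = datasets.Dataset(filtered_table, responses.info, responses.split)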
    

    I'm not 100% sure, though, that the last line is the correct way to recreate your dataset from an Arrow table.