I'm trying to filter a dataset by the ids in a list, but this approach is too slow. The dataset is an Arrow dataset imported from Hugging Face:
```python
import numpy as np
from datasets import load_dataset, DatasetDict
from collections import Counter
import pyarrow as pa
import pandas as pd

responses = load_dataset('peixian/rtGender', 'responses', split='train')
# post_id_test_list contains list of ids
responses_test = responses.filter(lambda x: x['post_id'] in post_id_test_list)
```
The object returned by `load_dataset` isn't an Arrow dataset but a Hugging Face `datasets.Dataset`. It is, however, backed by an Arrow table.
Applying a Python lambda filter row by row is going to be slow. If you want a faster, vectorized operation, you could try modifying the underlying Arrow table directly:
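```python
import pyarrow as pa
import pyarrow.compute as compute
from datasets import Dataset

table = responses.data.table  # the underlying pyarrow.Table
# Build the value set with the column's own type so is_in compares like with like
value_set = pa.array(post_id_test_list, type=table['post_id'].type)
flags = compute.is_in(table['post_id'], value_set=value_set)
filtered_table = table.filter(flags)
# Wrap the filtered Arrow table back into a Hugging Face Dataset
filtered_responses = Dataset(filtered_table, info=responses.info, split=responses.split)
```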
Though I'm not 100% sure the last line is the correct way to recreate your dataset from an Arrow table.