Tags: python, huggingface-transformers, huggingface, huggingface-datasets

How do I fix an interleaved dataset that only samples from one dataset?


The following code

from datasets import interleave_datasets, load_dataset

# Load each dataset in streaming mode
c4 = load_dataset("c4", "en", split="train", streaming=True)
wikitext = load_dataset("wikitext", "wikitext-103-v1", split="train", streaming=True)

# Interleave the two streams, sampling each with probability 0.5
datasets = [c4, wikitext]
for dataset in datasets:
    print(dataset.description)
interleaved = interleave_datasets(datasets, probabilities=[0.5, 0.5])
print(interleaved)

appears to sample from only one dataset. Why? Printing the keys of each example gives the same keys every time:

example.keys()=dict_keys(['text', 'timestamp', 'url'])
example.keys()=dict_keys(['text', 'timestamp', 'url'])
example.keys()=dict_keys(['text', 'timestamp', 'url'])
...
counts=100

colab: https://colab.research.google.com/drive/1VIR66U1d7qk3Q1vU_URoo5tHEEheORpN?usp=sharing


Solution

  • The interleave_datasets function works correctly here; it is your conclusion that is incorrect. When two datasets are interleaved, their features are combined, so every example exposes the union of both datasets' columns.

    These are the features of c4 and wikitext:

    print(c4.column_names)
    
    >>> ['text', 'timestamp', 'url']
    
    print(wikitext.column_names)
    
    >>> ['text']
    

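    When you combine the datasets, every example in the new dataset has the features ['text', 'timestamp', 'url'], even if it comes from the wikitext dataset. Since wikitext has no timestamp or url features, those values are None for its examples. Printing example.keys() therefore shows the same keys for both sources and cannot tell you which dataset an example came from.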

    Dummy example:

    from datasets import Dataset, interleave_datasets
    d1 = Dataset.from_dict({
      'feature_1': ['A', 'B', 'C']
    })
    d2 = Dataset.from_dict({
      'feature_2': [1, 2, 3]
    })
    
    dataset = interleave_datasets([d1, d2], probabilities=[0.5, 0.5], seed=42)
    print('Features:', dataset.column_names)
    
    for e in dataset:
      print(e)
    

    Output (iteration stops once one of the datasets is exhausted, which is the default stopping strategy):

    Features: ['feature_1', 'feature_2']
    {'feature_1': None, 'feature_2': 1}
    {'feature_1': 'A', 'feature_2': None}
    {'feature_1': None, 'feature_2': 2}
    {'feature_1': None, 'feature_2': 3}
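
To check that both sources really are being sampled, one option is to look at the values rather than the keys. The sketch below is only an illustration: guess_source and count_sources are hypothetical helpers (not part of the datasets library), and they assume that a None timestamp can only come from wikitext, which holds only if every c4 example carries a timestamp. The sample list is made-up data standing in for the interleaved stream.

```python
from collections import Counter
from itertools import islice
from typing import Any, Dict, Iterable


def guess_source(example: Dict[str, Any]) -> str:
    """Guess the origin of an interleaved example (hypothetical helper).

    Assumption: only c4 provides a 'timestamp', so a missing or None
    value is taken to mean the example came from wikitext.
    """
    return "wikitext" if example.get("timestamp") is None else "c4"


def count_sources(examples: Iterable[Dict[str, Any]], limit: int = 100) -> Counter:
    """Count guessed sources over at most `limit` examples."""
    return Counter(guess_source(example) for example in islice(examples, limit))


# Made-up examples shaped like the interleaved output described above.
sample = [
    {"text": "a", "timestamp": "2019-04-25T12:57:54Z", "url": "https://example.com"},
    {"text": "b", "timestamp": None, "url": None},
    {"text": "c", "timestamp": None, "url": None},
]

print(count_sources(sample))  # Counter({'wikitext': 2, 'c4': 1})
```

With the streaming datasets from the question, count_sources(interleaved) should report a mix of both labels if the sampling behaves as described. A more explicit alternative would be to add a source field to each dataset with .map(...) before interleaving, so the origin does not have to be inferred from None values.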