huggingface-datasets

How does one make `dataset.take(512)` work with `streaming=False` for a Hugging Face dataset?


I get the error:

Exception has occurred: AttributeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'Dataset' object has no attribute 'take'
  File "/lfs/ampere1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 499, in experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights
    batch = dataset.take(batch_size)
  File "/lfs/ampere1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 552, in <module>
    experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights()
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
AttributeError: 'Dataset' object has no attribute 'take'

This only happens with `streaming=False`. How do I fix it? I do NOT want to stream the data, and I want `.take` to work.
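For context on why the attribute is missing: a non-streaming `Dataset` is map-style (random access by index, hence `.select`), while a streaming `IterableDataset` only supports sequential access (hence `.take`). A stdlib-only sketch of the two access patterns, with plain lists and iterators standing in for the real HF objects:

```python
from itertools import islice

# Map-style: supports len() and random access by index,
# like a non-streaming Dataset.
data = list(range(1000))
first_512_select = [data[i] for i in range(512)]  # roughly what .select(range(512)) does

# Iterable-style: sequential access only,
# like a streaming IterableDataset.
stream = iter(range(1000))
first_512_take = list(islice(stream, 512))  # roughly what .take(512) does
```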


idea: convert the HF `Dataset` to an HF `datasets.iterable_dataset.IterableDataset`

This seems to have worked:

    print(f'{dataset=}')
    print(f'{type(dataset)=}')  # datasets.arrow_dataset.Dataset or datasets.iterable_dataset.IterableDataset
    # Dataset.to_iterable_dataset() (datasets >= 2.8) is the supported conversion;
    # IterableDataset(dataset) does not work, because the constructor expects an
    # examples iterable, not a Dataset.
    if not isinstance(dataset, IterableDataset):
        dataset = dataset.to_iterable_dataset()  # force dataset.take(batch_size) to work in non-streaming mode
    batch = dataset.take(batch_size)

This seems to work, although fetching a batch takes a long time.


Idea 2: Collate fn

If custom collate fns actually worked, this would be simple, since the collate fn already receives a batch of `batch_size` examples. See it not working here: "How to use huggingface HF trainer train with custom collate function?" However, that is with the Trainer, and I don't actually want to use the Trainer...
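For what it's worth, the collate fn itself is trivial; the hard part is wiring it into the training loop. A minimal sketch of what a custom collate fn does when passed as `collate_fn` to a plain `torch.utils.data.DataLoader` (no Trainer needed). The `DataLoader` step is simulated here without torch so the sketch stands alone:

```python
def collate_text_batch(examples):
    """Collate a list of {'text': ...} examples into one batched dict.

    Passed as collate_fn=collate_text_batch to a DataLoader with
    batch_size=512, this function would receive 512 examples per call,
    so no .take/.select call is needed at all.
    """
    return {"text": [ex["text"] for ex in examples]}

# Simulate what DataLoader does internally: gather examples, then collate.
examples = [{"text": f"doc {i}"} for i in range(4)]
batch = collate_text_batch(examples)
```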


Might be related to issues where I try to stream the dataset but HF seems to kill my requests?

    httplib_response = self._make_request(
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 466, in _make_request
    six.raise_from(e, None)
  File "<string>", line 3, in raise_from
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 461, in _make_request
    httplib_response = conn.getresponse()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 1375, in getresponse
    response.begin()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 318, in begin
    version, status, reason = self._read_status()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 287, in _read_status
    raise RemoteDisconnected("Remote end closed connection without"
http.client.RemoteDisconnected: Remote end closed connection without response
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/requests/adapters.py", line 486, in send
    resp = conn.urlopen(
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 798, in urlopen
    retries = retries.increment(
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/util/retry.py", line 550, in increment
    raise six.reraise(type(error), error, _stacktrace)
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/packages/six.py", line 769, in reraise
    raise value.with_traceback(tb)
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 714, in urlopen
    httplib_response = self._make_request(
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 466, in _make_request
    six.raise_from(e, None)
  File "<string>", line 3, in raise_from
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 461, in _make_request
    httplib_response = conn.getresponse()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 1375, in getresponse
    response.begin()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 318, in begin
    version, status, reason = self._read_status()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 287, in _read_status
    raise RemoteDisconnected("Remote end closed connection without"
urllib3.exceptions.ProtocolError: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/lfs/hyperturing1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 578, in <module>
    # -- Finish wandb
  File "/lfs/hyperturing1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 540, in experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights
    print(f'{batch=}')
  File "/lfs/hyperturing1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 63, in get_diversity_coefficient
    embedding, loss = Task2Vec(probe_network, classifier_opts={'seed': seed}).embed(tokenized_batch)
  File "/afs/cs.stanford.edu/u/brando9/beyond-scale-language-data-diversity/src/diversity/task2vec.py", line 133, in embed
    loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
  File "/afs/cs.stanford.edu/u/brando9/beyond-scale-language-data-diversity/src/diversity/task2vec.py", line 198, in _finetune_classifier
    for step, batch in enumerate(epoch_iterator):
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/tqdm/std.py", line 1182, in __iter__
    for obj in iterable:
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 1353, in __iter__
    for key, example in ex_iterable:
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 652, in __iter__
    yield from self._iter()
  File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 667, in _iter
    for key, example in iterator:
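If the remote end really is dropping connections mid-stream, one generic workaround is to rebuild the iterator on transient errors and skip what was already consumed. This is my own sketch, not a `datasets` API; in real use `retry_on` would include `requests.exceptions.ConnectionError` and `http.client.RemoteDisconnected`, and resuming correctly assumes a deterministic iteration order:

```python
import time

def retry_iter(make_iterator, retry_on=(ConnectionError,), max_retries=3, backoff_s=1.0):
    """Yield from a freshly built iterator, rebuilding it after transient errors.

    make_iterator: zero-arg callable returning a new iterator, e.g.
        lambda: iter(streaming_dataset.take(512))
    Examples already yielded are skipped on retry, so the stream resumes
    roughly where it broke.
    """
    yielded = 0
    attempts = 0
    while True:
        try:
            it = make_iterator()
            for _ in range(yielded):  # skip examples we already produced
                next(it)
            for example in it:
                yielded += 1
                yield example
            return
        except retry_on:
            attempts += 1
            if attempts > max_retries:
                raise
            time.sleep(backoff_s)
```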



Solution

  • The solution for now seems to be:

    batch = dataset.select(range(512))
    

    or

        sample_data = train_dataset.select(range(batch_size)) if not isinstance(train_dataset, datasets.iterable_dataset.IterableDataset) else train_dataset.take(batch_size)
    

    or

        raw_text_batch = shuffled_dataset.take(batch_size) if streaming else shuffled_dataset.select(random.sample(range(len(shuffled_dataset)), batch_size))
    

    or, for random indices:

        for batch_num in range(num_batches):
            # - Get batch (buffer_size only applies to IterableDataset.shuffle;
            # Dataset.shuffle takes just a seed)
            shuffled_dataset = (dataset.shuffle(buffer_size=buffer_size, seed=seed) if streaming else dataset.shuffle(seed=seed)) if shuffle else dataset
            raw_text_batch = shuffled_dataset.take(batch_size) if streaming else shuffled_dataset.select(random.sample(range(len(shuffled_dataset)), batch_size))
            tokenized_batch = map(raw_text_batch)  # `map` here is my own tokenization helper, not the builtin
            if verbose:
                print(f'{raw_text_batch=}')
                print(f'{tokenized_batch=}')
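The streaming/non-streaming branches can be folded into one helper. This is my sketch, with duck-typed stand-ins rather than the real HF classes; note that `random.sample` takes a population (e.g. `range(len(dataset))`) as its first argument, not a count:

```python
import random

def get_raw_batch(dataset, batch_size, streaming, seed=None):
    """Fetch batch_size examples from either kind of HF dataset.

    Streaming IterableDataset: sequential .take(batch_size).
    Non-streaming Dataset: .select() over batch_size random indices,
    sampled without replacement from range(len(dataset)).
    """
    if streaming:
        return dataset.take(batch_size)
    rng = random.Random(seed)
    indices = rng.sample(range(len(dataset)), batch_size)
    return dataset.select(indices)
```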