I get the error:
Exception has occurred: AttributeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'Dataset' object has no attribute 'take'
File "/lfs/ampere1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 499, in experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights
batch = dataset.take(batch_size)
File "/lfs/ampere1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 552, in <module>
experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights()
File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/lfs/ampere1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
AttributeError: 'Dataset' object has no attribute 'take'
This only happens with streaming=False.
How can I fix it? I do NOT want to stream the data, and I want .take to work.
This seems to have worked:
print(f'{dataset=}')
print(f'{type(dataset)=}')
# datasets.iterable_dataset.IterableDataset
# datasets.arrow_dataset.Dataset
dataset = IterableDataset(dataset) if type(dataset) != IterableDataset else dataset # to force dataset.take(batch_size) to work in non-streaming mode
batch = dataset.take(batch_size)
This seems to work, but it takes a long time to fetch a batch.
If the custom collate fns actually worked, this would be simple, since the collate fn would already receive a batch of batch_size examples. See that it doesn't work here: "How to use huggingface HF trainer train with custom collate function?" — but actually, I don't want to use the trainer...
This might be related to issues where I try to stream the dataset but HF seems to kill my requests?
2654 httplib_response = self._make_request(
2655 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 466, in _make_request
2656 six.raise_from(e, None)
2657 File "<string>", line 3, in raise_from
2658 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 461, in _make_request
2659 httplib_response = conn.getresponse()
2660 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 1375, in getresponse
2661 response.begin()
2662 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 318, in begin
2663 version, status, reason = self._read_status()
2664 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 287, in _read_status
2665 raise RemoteDisconnected("Remote end closed connection without"
2666 http.client.RemoteDisconnected: Remote end closed connection without response
2667 During handling of the above exception, another exception occurred:
2668 Traceback (most recent call last):
2669 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/requests/adapters.py", line 486, in send
2670 resp = conn.urlopen(
2671 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 798, in urlopen
2672 retries = retries.increment(
2673 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/util/retry.py", line 550, in increment
2674 raise six.reraise(type(error), error, _stacktrace)
2675 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/packages/six.py", line 769, in reraise
2676 raise value.with_traceback(tb)
2677 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 714, in urlopen
2678 httplib_response = self._make_request(
2679 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 466, in _make_request
2680 six.raise_from(e, None)
2681 File "<string>", line 3, in raise_from
2682 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/urllib3/connectionpool.py", line 461, in _make_request
2683 httplib_response = conn.getresponse()
2684 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 1375, in getresponse
2685 response.begin()
2686 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 318, in begin
2687 version, status, reason = self._read_status()
2688 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/http/client.py", line 287, in _read_status
2689 raise RemoteDisconnected("Remote end closed connection without"
2690 urllib3.exceptions.ProtocolError: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
2691 During handling of the above exception, another exception occurred:
2692 Traceback (most recent call last):
2693 File "/lfs/hyperturing1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 578, in <module>
2694 # -- Finish wandb
2695 File "/lfs/hyperturing1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 540, in experiment_compute_diveristy_coeff_single_dataset_then_combined_datasets_with_domain_weights
2696 print(f'{batch=}')
2697 File "/lfs/hyperturing1/0/brando9/beyond-scale-language-data-diversity/src/diversity/div_coeff.py", line 63, in get_diversity_coefficient
2698 embedding, loss = Task2Vec(probe_network, classifier_opts={'seed': seed}).embed(tokenized_batch)
2699 File "/afs/cs.stanford.edu/u/brando9/beyond-scale-language-data-diversity/src/diversity/task2vec.py", line 133, in embed
2700 loss = self._finetune_classifier(dataset, loader_opts=self.loader_opts, classifier_opts=self.classifier_opts, max_samples=self.max_samples, epochs=epochs)
2701 File "/afs/cs.stanford.edu/u/brando9/beyond-scale-language-data-diversity/src/diversity/task2vec.py", line 198, in _finetune_classifier
2702 for step, batch in enumerate(epoch_iterator):
2703 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/tqdm/std.py", line 1182, in __iter__
2704 for obj in iterable:
2705 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
2706 data = self._next_data()
2707 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
2708 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
2709 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
2710 data.append(next(self.dataset_iter))
2711 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 1353, in __iter__
2712 for key, example in ex_iterable:
2713 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 652, in __iter__
2714 yield from self._iter()
2715 File "/lfs/hyperturing1/0/brando9/miniconda/envs/beyond_scale/lib/python3.10/site-packages/datasets/iterable_dataset.py", line 667, in _iter
2716 for key, example in iterator:
The solution for now seems to be:
batch = dataset.select(range(512))
or
sample_data = train_dataset.select(range(batch_size)) if not isinstance(train_dataset, datasets.iterable_dataset.IterableDataset) else train_dataset.take(batch_size)
or
raw_text_batch = shuffled_dataset.take(batch_size) if streaming else shuffled_dataset.select(random.sample(range(len(shuffled_dataset)), batch_size))
or, for random indices:
for batch_num in range(num_batches):
# - Get batch
shuffled_dataset = dataset.shuffle(buffer_size=buffer_size, seed=seed) if shuffle else dataset
raw_text_batch = shuffled_dataset.take(batch_size) if streaming else shuffled_dataset.select(random.sample(range(len(shuffled_dataset)), batch_size))
tokenized_batch = map(raw_text_batch)
if verbose:
print(f'{raw_text_batch=}')
print(f'{tokenized_batch=}')