Given a datasets.iterable_dataset.IterableDataset
created with streaming=True
, e.g.
train_data = load_dataset("csv", data_files="../input/tatoeba/tatoeba-sentpairs.tsv",
streaming=True, delimiter="\t", split="train")
and trying to use it in a Trainer
object, e.g.
# instantiate trainer
trainer = Seq2SeqTrainer(
model=multibert,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_data,
eval_dataset=train_data,
)
trainer.train()
It throws an error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_27/3002801805.py in <module>
28 )
29
---> 30 trainer.train()
/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1411 resume_from_checkpoint=resume_from_checkpoint,
1412 trial=trial,
-> 1413 ignore_keys_for_eval=ignore_keys_for_eval,
1414 )
1415
/opt/conda/lib/python3.7/site-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1623
1624 step = -1
-> 1625 for step, inputs in enumerate(epoch_iterator):
1626
1627 # Skip past any already trained steps if resuming training
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
528 if self._sampler_iter is None:
529 self._reset()
--> 530 data = self._next_data()
531 self._num_yielded += 1
532 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
567
568 def _next_data(self):
--> 569 index = self._next_index() # may raise StopIteration
570 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
571 if self._pin_memory:
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_index(self)
519
520 def _next_index(self):
--> 521 return next(self._sampler_iter) # may raise StopIteration
522
523 def _next_data(self):
/opt/conda/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
224 def __iter__(self) -> Iterator[List[int]]:
225 batch = []
--> 226 for idx in self.sampler:
227 batch.append(idx)
228 if len(batch) == self.batch_size:
/opt/conda/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
64
65 def __iter__(self) -> Iterator[int]:
---> 66 return iter(range(len(self.data_source)))
67
68 def __len__(self) -> int:
TypeError: object of type 'IterableDataset' has no len()
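The traceback shows that the DataLoader built by the Trainer uses an index-based sampler, and that sampler needs len(self.data_source); a streaming datasets.IterableDataset defines no __len__. A minimal sketch that reproduces just that failure (reusing the same TSV path, so the file location is an assumption about the local setup):
from datasets import load_dataset

# A streaming dataset can be iterated but defines no length, which is exactly
# what the sampler in the traceback above complains about.
stream_ds = load_dataset("csv", data_files="../input/tatoeba/tatoeba-sentpairs.tsv",
                         streaming=True, delimiter="\t", split="train")
len(stream_ds)  # TypeError: object of type 'IterableDataset' has no len()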
This can be resolved by wrapping the IterableDataset
object with IterableWrapper
from the torchdata
library.
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
...
# instantiate trainer
trainer = Seq2SeqTrainer(
model=multibert,
tokenizer=tokenizer,
args=training_args,
train_dataset=IterableWrapper(train_data),
eval_dataset=IterableWrapper(train_data),
)
trainer.train()
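The wrapper presumably works because IterableWrapper returns an IterDataPipe, which, as far as I understand, is a subclass of torch.utils.data.IterableDataset, so the Trainer's DataLoader takes the iterable code path and never asks for a length. A quick sanity check (assuming torchdata is installed):
import torch
from torchdata.datapipes.iter import IterableWrapper

# If IterDataPipe really subclasses torch.utils.data.IterableDataset, the
# DataLoader will skip the index-based sampler that raised the error above.
wrapped = IterableWrapper(train_data)
print(isinstance(wrapped, torch.utils.data.IterableDataset))  # expected: True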
Is there a way to use the IterableDataset
with Seq2SeqTrainer
without casting it with IterableWrapper
?
For reference, a full working example looks like the code below; changing the line train_dataset=IterableWrapper(train_data)
to train_dataset=train_data
will reproduce the TypeError: object of type 'IterableDataset' has no len()
error.
import torch
from datasets import load_dataset
from transformers import EncoderDecoderModel
from transformers import AutoTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper
multibert = EncoderDecoderModel.from_encoder_decoder_pretrained(
"bert-base-multilingual-uncased", "bert-base-multilingual-uncased"
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-uncased")
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# set special tokens
multibert.config.decoder_start_token_id = tokenizer.bos_token_id
multibert.config.eos_token_id = tokenizer.eos_token_id
multibert.config.pad_token_id = tokenizer.pad_token_id
# sensible parameters for beam search
multibert.config.vocab_size = multibert.config.decoder.vocab_size
def process_data_to_model_inputs(batch, max_len=10):
    inputs = tokenizer(batch["SRC"], padding="max_length",
                       truncation=True, max_length=max_len)
    outputs = tokenizer(batch["TRG"], padding="max_length",
                        truncation=True, max_length=max_len)
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["decoder_attention_mask"] = outputs.attention_mask
    batch["labels"] = outputs.input_ids.copy()
    # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`.
    # We have to make sure that the PAD token is ignored
    batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token
                        for token in labels] for labels in batch["labels"]]
    return batch
# tatoeba-sentpairs.tsv is a pretty large file.
train_data = load_dataset("csv", data_files="../input/tatoeba/tatoeba-sentpairs.tsv",
streaming=True, delimiter="\t", split="train")
train_data = train_data.map(process_data_to_model_inputs, batched=True)
batch_size = 1
# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
output_dir="./",
evaluation_strategy="steps",
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
predict_with_generate=True,
logging_steps=2, # set to 1000 for full training
save_steps=16, # set to 500 for full training
eval_steps=4, # set to 8000 for full training
warmup_steps=1, # set to 2000 for full training
max_steps=16, # delete for full training
# overwrite_output_dir=True,
save_total_limit=1,
#fp16=True,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=multibert,
tokenizer=tokenizer,
args=training_args,
train_dataset=IterableWrapper(train_data),
eval_dataset=IterableWrapper(train_data),
)
trainer.train()
Found the answer at https://discuss.huggingface.co/t/using-iterabledataset-with-trainer-iterabledataset-has-no-len/15790
By calling with_format("torch") on the iterable dataset, like this:
train_data.with_format("torch")
the Trainer works without throwing the len()
error.
# instantiate trainer
trainer = Seq2SeqTrainer(
model=multibert,
tokenizer=tokenizer,
args=training_args,
train_dataset=train_data.with_format("torch"),
eval_dataset=train_data.with_format("torch"),
)
trainer.train()
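As a sanity check (my own assumption about how datasets exposes the torch format, not something stated in the linked thread), the formatted streaming dataset should be recognised by torch's DataLoader as an iterable-style dataset:
import torch

torch_ds = train_data.with_format("torch")
# If with_format("torch") behaves as expected, the DataLoader treats this as an
# iterable-style dataset and never builds the index-based sampler that needs len().
print(isinstance(torch_ds, torch.utils.data.IterableDataset))  # expected: True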