Faster datasets train example

Using .iter() is much faster than accessing rows by id
Author: Quentin Lhoest
Date: 2023-03-23 11:24:30 +01:00
Committed by: GitHub
Parent: b8fbea00a9
Commit: e76f900bc0


@@ -9,15 +9,15 @@ bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
 bpe_tokenizer.normalizer = normalizers.Lowercase()
 # Initialize a dataset
-dataset = datasets.load_dataset("wikitext", "wikitext-103-raw-v1")
+dataset = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
 # Build an iterator over this dataset
 def batch_iterator():
-    batch_length = 1000
-    for i in range(0, len(dataset["train"]), batch_length):
-        yield dataset["train"][i : i + batch_length]["text"]
+    batch_size = 1000
+    for batch in dataset.iter(batch_size=batch_size):
+        yield batch["text"]
 # And finally train
-bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset["train"]))
+bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset))
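For reference, a minimal self-contained version of the updated example might look like the sketch below. The diff only shows the file from line 9 onward, so the imports and the Tokenizer(models.BPE()) construction are assumptions inferred from the identifiers used in the snippet, not part of this commit.

# Runnable sketch of the updated example. The imports and the BPE model
# construction are assumed; everything from the normalizer down mirrors
# the new side of the diff.
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers
import datasets

bpe_tokenizer = Tokenizer(models.BPE())
bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
bpe_tokenizer.normalizer = normalizers.Lowercase()

# Initialize a dataset (loading the train split directly removes the
# need for dataset["train"] indexing later)
dataset = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

# Build an iterator over this dataset: .iter() streams contiguous batches,
# avoiding the repeated row-id slicing of dataset[i : i + batch_size]
def batch_iterator():
    batch_size = 1000
    for batch in dataset.iter(batch_size=batch_size):
        yield batch["text"]

# And finally train; length is used for progress reporting
bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset))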