diff --git a/bindings/python/tests/documentation/test_tutorial_train_from_iterators.py b/bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
index bba55a48..58d93351 100644
--- a/bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
+++ b/bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
@@ -70,8 +70,10 @@ class TestTrainFromIterators:
 
         # START def_batch_iterator
         def batch_iterator(batch_size=1000):
-            for i in range(0, len(dataset), batch_size):
-                yield dataset[i : i + batch_size]["text"]
+            # Only keep the text column to avoid decoding the rest of the columns unnecessarily
+            tok_dataset = dataset.select_columns("text")
+            for batch in tok_dataset.iter(batch_size):
+                yield batch["text"]
 
         # END def_batch_iterator
 