Doc - Add documentation for training from iterators

This commit is contained in:
Anthony MOI
2021-01-12 15:30:01 -05:00
committed by Anthony MOI
parent 7bee825238
commit 91dae1de15
6 changed files with 238 additions and 3 deletions

View File

@ -100,7 +100,7 @@ jobs:
run: |
python -m venv .env
source .env/bin/activate
pip install pytest requests setuptools_rust numpy
pip install pytest requests setuptools_rust numpy datasets
python setup.py develop
- name: Check style

View File

@ -0,0 +1,101 @@
from ..utils import data_dir, train_files
import os
import pytest
import datasets
import gzip
class TestTrainFromIterators:
@staticmethod
def get_tokenizer_trainer():
# START init_tokenizer_trainer
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, decoders, trainers
tokenizer = Tokenizer(models.Unigram())
tokenizer.normalizer = normalizers.NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()
tokenizer.decoders = decoders.ByteLevel()
trainer = trainers.UnigramTrainer(
vocab_size=20000,
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
special_tokens=["<PAD>", "<BOS>", "<EOS>"],
)
# END init_tokenizer_trainer
trainer.show_progress = False
return tokenizer, trainer
@staticmethod
def load_dummy_dataset():
# START load_dataset
import datasets
dataset = datasets.load_dataset(
"wikitext", "wikitext-103-raw-v1", split="train+test+validation"
)
# END load_dataset
@pytest.fixture(scope="class")
def setup_gzip_files(self, train_files):
with open(train_files["small"], "rt") as small:
for n in range(3):
path = f"data/my-file.{n}.gz"
with gzip.open(path, "wt") as f:
f.write(small.read())
def test_train_basic(self):
tokenizer, trainer = self.get_tokenizer_trainer()
# START train_basic
# First few lines of the "Zen of Python" https://www.python.org/dev/peps/pep-0020/
data = [
"Beautiful is better than ugly."
"Explicit is better than implicit."
"Simple is better than complex."
"Complex is better than complicated."
"Flat is better than nested."
"Sparse is better than dense."
"Readability counts."
]
tokenizer.train_from_iterator(data, trainer=trainer)
# END train_basic
def test_datasets(self):
tokenizer, trainer = self.get_tokenizer_trainer()
# In order to keep tests fast, we only use the first 100 examples
os.environ["TOKENIZERS_PARALLELISM"] = "true"
dataset = datasets.load_dataset("wikitext", "wikitext-103-raw-v1", split="train[0:100]")
# START def_batch_iterator
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i : i + batch_size]["text"]
# END def_batch_iterator
# START train_datasets
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))
# END train_datasets
def test_gzip(self, setup_gzip_files):
tokenizer, trainer = self.get_tokenizer_trainer()
# START single_gzip
import gzip
with gzip.open("data/my-file.0.gz", "rt") as f:
tokenizer.train_from_iterator(f, trainer=trainer)
# END single_gzip
# START multi_gzip
files = ["data/my-file.0.gz", "data/my-file.1.gz", "data/my-file.2.gz"]
def gzip_iterator():
for path in files:
with gzip.open(path, "rt") as f:
for line in f:
yield line
tokenizer.train_from_iterator(gzip_iterator(), trainer=trainer)
# END multi_gzip

View File

@ -0,0 +1,29 @@
import re
from sphinx.directives.other import TocTree
class TocTreeTags(TocTree):
hasPat = re.compile("^\s*:(.+):(.+)$")
def filter_entries(self, entries):
filtered = []
for e in entries:
m = self.hasPat.match(e)
if m != None:
if self.env.app.tags.has(m.groups()[0]):
filtered.append(m.groups()[1])
else:
filtered.append(e)
return filtered
def run(self):
self.content = self.filter_entries(self.content)
return super().run()
def setup(app):
app.add_directive("toctree-tags", TocTreeTags)
return {
"version": "0.1",
}

View File

@ -39,7 +39,7 @@ rust_version = "latest"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "entities", "rust_doc"]
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "entities", "rust_doc", "toctree_tags"]
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
@ -49,7 +49,6 @@ templates_path = ["_templates"]
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
@ -70,6 +69,10 @@ html_static_path = ["_static"]
def setup(app):
for language in languages:
if not tags.has(language):
exclude_patterns.append(f"tutorials/{language}/*")
app.add_css_file("css/huggingface.css")
app.add_css_file("css/code-snippets.css")
app.add_js_file("js/custom.js")

View File

@ -32,6 +32,13 @@ Main features:
pipeline
components
.. toctree-tags::
:maxdepth: 3
:caption: Using 🤗 Tokenizers
:glob:
:python:tutorials/python/*
.. toctree::
:maxdepth: 3
:caption: API Reference

View File

@ -0,0 +1,95 @@
Training from memory
----------------------------------------------------------------------------------------------------
In the `Quicktour <quicktour.html>`__, we saw how to build and train a tokenizer using text files,
but we can actually use any Python Iterator. In this section we'll see a few different ways of
training our tokenizer.
For all the examples listed below, we'll use the same :class:`~tokenizers.Tokenizer` and
:class:`~tokenizers.trainers.Trainer`, built as following:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START init_tokenizer_trainer
:end-before: END init_tokenizer_trainer
:dedent: 8
This tokenizer is based on the :class:`~tokenizers.models.Unigram` model. It takes care of
normalizing the input using the NFKC Unicode normalization method, and uses a
:class:`~tokenizers.pre_tokenizers.ByteLevel` pre-tokenizer with the corresponding decoder.
For more information on the components used here, you can check `here <components.html>`__
The most basic way
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
As you probably guessed already, the easiest way to train our tokenizer is by using a :obj:`List`:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START train_basic
:end-before: END train_basic
:dedent: 8
Easy, right? You can use anything working as an iterator here, be it a :obj:`List`, :obj:`Tuple`,
or a :obj:`np.Array`. Anything works as long as it provides strings.
Using the 🤗 Datasets library
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
An awesome way to access one of the many datasets that exist out there is by using the 🤗 Datasets
library. For more information about it, you should check
`the official documentation here <https://huggingface.co/docs/datasets/>`__.
Let's start by loading our dataset:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START load_dataset
:end-before: END load_dataset
:dedent: 8
The next step is to build an iterator over this dataset. The easiest way to do this is probably by
using a generator:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START def_batch_iterator
:end-before: END def_batch_iterator
:dedent: 8
As you can see here, for improved efficiency we can actually provide a batch of examples used
to train, instead of iterating over them one by one. By doing so, we can expect performances very
similar to those we got while training directly from files.
With our iterator ready, we just need to launch the training. In order to improve the look of our
progress bars, we can specify the total length of the dataset:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START train_datasets
:end-before: END train_datasets
:dedent: 8
And that's it!
Using gzip files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Since gzip files in Python can be used as iterators, it is extremely simple to train on such files:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START single_gzip
:end-before: END single_gzip
:dedent: 8
Now if we wanted to train from multiple gzip files, it wouldn't be much harder:
.. literalinclude:: ../../../../bindings/python/tests/documentation/test_tutorial_train_from_iterators.py
:language: python
:start-after: START multi_gzip
:end-before: END multi_gzip
:dedent: 8
And voilà!