diff --git a/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py b/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py
index 74c39662..fc9ef2ee 100644
--- a/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py
+++ b/bindings/python/py_src/tokenizers/implementations/sentencepiece_unigram.py
@@ -40,3 +40,20 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         }
 
         super().__init__(tokenizer, parameters)
+
+    def train(
+        self,
+        files: Union[str, List[str]],
+        vocab_size: int = 8000,
+        show_progress: bool = True,
+        special_tokens: List[Union[str, AddedToken]] = [],
+    ):
+        """ Train the model using the given files """
+
+        trainer = trainers.UnigramTrainer(
+            vocab_size=vocab_size, special_tokens=special_tokens, show_progress=show_progress,
+        )
+
+        if isinstance(files, str):
+            files = [files]
+        self._tokenizer.train(trainer, files)
diff --git a/bindings/python/py_src/tokenizers/trainers/__init__.py b/bindings/python/py_src/tokenizers/trainers/__init__.py
index 2ec085ff..9ddbff2b 100644
--- a/bindings/python/py_src/tokenizers/trainers/__init__.py
+++ b/bindings/python/py_src/tokenizers/trainers/__init__.py
@@ -3,3 +3,4 @@ from .. import trainers
 Trainer = trainers.Trainer
 BpeTrainer = trainers.BpeTrainer
 WordPieceTrainer = trainers.WordPieceTrainer
+UnigramTrainer = trainers.UnigramTrainer
diff --git a/bindings/python/py_src/tokenizers/trainers/__init__.pyi b/bindings/python/py_src/tokenizers/trainers/__init__.pyi
index d355b9b7..1042c494 100644
--- a/bindings/python/py_src/tokenizers/trainers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/trainers/__init__.pyi
@@ -111,3 +111,32 @@ class WordPieceTrainer(Trainer):
             Trainer
         """
         pass
+
+class UnigramTrainer(Trainer):
+    """ UnigramTrainer
+
+    Capable of training a Unigram model
+    """
+
+    def __init__(
+        self,
+        vocab_size: int = 8000,
+        show_progress: bool = True,
+        special_tokens: List[Union[str, AddedToken]] = [],
+    ) -> Trainer:
+        """ Instantiate a new UnigramTrainer with the given options:
+
+        Args:
+            vocab_size: unsigned int:
+                The size of the final vocabulary, including all tokens and alphabet.
+
+            show_progress: boolean:
+                Whether to show progress bars while training.
+
+            special_tokens: List[Union[str, AddedToken]]:
+                A list of special tokens the model should know of.
+
+        Returns:
+            Trainer
+        """
+        pass
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 27e01076..2616883e 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -41,6 +41,7 @@ fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<trainers::Trainer>()?;
     m.add_class::<trainers::BpeTrainer>()?;
     m.add_class::<trainers::WordPieceTrainer>()?;
+    m.add_class::<trainers::UnigramTrainer>()?;
     Ok(())
 }
 
diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
new file mode 100644
index 00000000..ae8bb7c5
--- /dev/null
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -0,0 +1,15 @@
+import os
+import pytest
+import pickle
+
+from tokenizers import SentencePieceUnigramTokenizer
+
+
+class TestUnigram:
+    def test_train(self):
+        tokenizer = SentencePieceUnigramTokenizer()
+        tokenizer.train("tests/data/unigram_wagahaiwa_nekodearu.txt", show_progress=False)
+
+        filename = "tests/data/unigram_trained.json"
+        tokenizer.save(filename)
+        os.remove(filename)