Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-05 12:18:20 +00:00)
Expose the trainer to Python bindings.
Committed by: Anthony MOI
Parent: 52082b5476
Commit: 558e76f18e
@@ -40,3 +40,20 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
        }

        super().__init__(tokenizer, parameters)

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int = 8000,
        show_progress: bool = True,
        special_tokens: List[Union[str, AddedToken]] = [],
    ):
        """ Train the model using the given files """

        trainer = trainers.UnigramTrainer(
            vocab_size=vocab_size, special_tokens=special_tokens, show_progress=show_progress,
        )

        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(trainer, files)
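For reference, the new train() helper added above can be driven end to end roughly as follows. This is a minimal usage sketch, not part of the diff; my_corpus.txt is a placeholder path and the special tokens are illustrative.

    from tokenizers import SentencePieceUnigramTokenizer

    # Train a fresh SentencePiece-style Unigram tokenizer on one or more text files
    tokenizer = SentencePieceUnigramTokenizer()
    tokenizer.train(
        "my_corpus.txt",                      # a single path or a list of paths
        vocab_size=8000,
        special_tokens=["<unk>", "<pad>"],
        show_progress=False,
    )

    # The trained tokenizer behaves like any other BaseTokenizer
    print(tokenizer.encode("a short test sentence").tokens)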
@@ -3,3 +3,4 @@ from .. import trainers
Trainer = trainers.Trainer
BpeTrainer = trainers.BpeTrainer
WordPieceTrainer = trainers.WordPieceTrainer
UnigramTrainer = trainers.UnigramTrainer
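With the re-export above, the new trainer is reachable directly from the trainers module next to the existing BPE and WordPiece trainers; a quick sanity-check sketch (not part of the diff):

    from tokenizers import trainers

    # The Unigram trainer is now importable like the other trainers
    trainer = trainers.UnigramTrainer(vocab_size=8000, show_progress=False)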
@@ -111,3 +111,32 @@ class WordPieceTrainer(Trainer):
            Trainer
        """
        pass

class UnigramTrainer(Trainer):
    """ UnigramTrainer

    Capable of training a Unigram model
    """

    def __init__(
        self,
        vocab_size: int = 8000,
        show_progress: bool = True,
        special_tokens: List[Union[str, AddedToken]] = [],
    ) -> Trainer:
        """ Instantiate a new UnigramTrainer with the given options:

        Args:
            vocab_size: unsigned int:
                The size of the final vocabulary, including all tokens and alphabet.

            show_progress: boolean:
                Whether to show progress bars while training.

            special_tokens: List[Union[str, AddedToken]]:
                A list of special tokens the model should know of.

        Returns:
            Trainer
        """
        pass
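The stub above documents the lower-level path: a UnigramTrainer can also be built by hand and passed to Tokenizer.train, mirroring what the SentencePieceUnigramTokenizer.train helper does in the first hunk. A sketch, assuming the Unigram model is exposed under tokenizers.models and can be constructed without an initial vocabulary; corpus.txt is a placeholder:

    from tokenizers import Tokenizer, trainers
    from tokenizers.models import Unigram

    # Build an empty Unigram model and train it with an explicitly configured trainer
    tokenizer = Tokenizer(Unigram())
    trainer = trainers.UnigramTrainer(
        vocab_size=8000,
        show_progress=True,
        special_tokens=["<unk>"],
    )
    tokenizer.train(trainer, ["corpus.txt"])  # same argument order as in the helper above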
@@ -41,6 +41,7 @@ fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<trainers::PyTrainer>()?;
    m.add_class::<trainers::PyBpeTrainer>()?;
    m.add_class::<trainers::PyWordPieceTrainer>()?;
    m.add_class::<trainers::PyUnigramTrainer>()?;
    Ok(())
}
bindings/python/tests/bindings/test_trainers.py (new file, 15 lines)
@@ -0,0 +1,15 @@
import os
import pytest
import pickle

from tokenizers import SentencePieceUnigramTokenizer


class TestUnigram:
    def test_train(self):
        tokenizer = SentencePieceUnigramTokenizer()
        tokenizer.train("tests/data/unigram_wagahaiwa_nekodearu.txt", show_progress=False)

        filename = "tests/data/unigram_trained.json"
        tokenizer.save(filename)
        os.remove(filename)
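The new test module can be invoked programmatically from the bindings/python directory; a small sketch, assuming pytest is installed and the sample corpus referenced by the test is available locally:

    import pytest

    # Run only the new Unigram training test
    pytest.main(["tests/bindings/test_trainers.py", "-k", "test_train"])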