mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-07 05:08:24 +00:00
Expose the trainer to Python bindings.
This commit is contained in:
committed by
Anthony MOI
parent
52082b5476
commit
558e76f18e
@@ -40,3 +40,20 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
super().__init__(tokenizer, parameters)
|
super().__init__(tokenizer, parameters)
|
||||||
|
|
||||||
|
def train(
    self,
    files: Union[str, List[str]],
    vocab_size: int = 8000,
    show_progress: bool = True,
    special_tokens: Union[List[Union[str, AddedToken]], None] = None,
):
    """Train the model using the given files.

    Args:
        files: A single file path, or a list of file paths, to train on.
        vocab_size: Forwarded to ``trainers.UnigramTrainer``.
        show_progress: Forwarded to ``trainers.UnigramTrainer``.
        special_tokens: Forwarded to ``trainers.UnigramTrainer``. ``None``
            (the default) is treated as an empty list.
    """
    # Build a fresh list per call rather than sharing one mutable default
    # object across every invocation of this method.
    if special_tokens is None:
        special_tokens = []

    trainer = trainers.UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        show_progress=show_progress,
    )

    # A bare string is accepted for convenience; the underlying tokenizer
    # is always handed a list of paths.
    if isinstance(files, str):
        files = [files]
    self._tokenizer.train(trainer, files)
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ from .. import trainers
|
|||||||
Trainer = trainers.Trainer
|
Trainer = trainers.Trainer
|
||||||
BpeTrainer = trainers.BpeTrainer
|
BpeTrainer = trainers.BpeTrainer
|
||||||
WordPieceTrainer = trainers.WordPieceTrainer
|
WordPieceTrainer = trainers.WordPieceTrainer
|
||||||
|
UnigramTrainer = trainers.UnigramTrainer
|
||||||
|
|||||||
@@ -111,3 +111,32 @@ class WordPieceTrainer(Trainer):
|
|||||||
Trainer
|
Trainer
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class UnigramTrainer(Trainer):
    """Trainer for Unigram models.

    Declares the constructor signature of the Unigram trainer; the body
    here is a placeholder only.
    """

    def __init__(
        self,
        vocab_size: int = 8000,
        show_progress: bool = True,
        special_tokens: List[Union[str, AddedToken]] = [],
    ) -> Trainer:
        """Create a new UnigramTrainer configured with the given options.

        Args:
            vocab_size (unsigned int):
                Size of the final vocabulary, counting every token as well
                as the alphabet.

            show_progress (boolean):
                Whether progress bars are displayed during training.

            special_tokens (List[Union[str, AddedToken]]):
                Special tokens the model should be made aware of.

        Returns:
            Trainer
        """
        pass
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ fn trainers(_py: Python, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add_class::<trainers::PyTrainer>()?;
|
m.add_class::<trainers::PyTrainer>()?;
|
||||||
m.add_class::<trainers::PyBpeTrainer>()?;
|
m.add_class::<trainers::PyBpeTrainer>()?;
|
||||||
m.add_class::<trainers::PyWordPieceTrainer>()?;
|
m.add_class::<trainers::PyWordPieceTrainer>()?;
|
||||||
|
m.add_class::<trainers::PyUnigramTrainer>()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
15
bindings/python/tests/bindings/test_trainers.py
Normal file
15
bindings/python/tests/bindings/test_trainers.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
from tokenizers import SentencePieceUnigramTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnigram:
    """Smoke test: a SentencePieceUnigramTokenizer can be trained and saved."""

    def test_train(self):
        """Train on the sample corpus, then check the tokenizer can be saved."""
        # Function-scope import so this block does not depend on a change to
        # the file's top-level import list.
        import tempfile

        tokenizer = SentencePieceUnigramTokenizer()
        tokenizer.train("tests/data/unigram_wagahaiwa_nekodearu.txt", show_progress=False)

        # Save into a throwaway directory instead of the checked-in
        # tests/data folder: the artifact is cleaned up even when save()
        # or the assertion fails, and the repository is never polluted.
        with tempfile.TemporaryDirectory() as tmp_dir:
            filename = os.path.join(tmp_dir, "unigram_trained.json")
            tokenizer.save(filename)
            # The original relied on os.remove() raising if nothing had been
            # written; keep that check explicit.
            assert os.path.isfile(filename)
|
||||||
Reference in New Issue
Block a user