Expose the trainer to Python bindings.

This commit is contained in:
Nicolas Patry
2020-09-01 16:14:12 +02:00
committed by Anthony MOI
parent 52082b5476
commit 558e76f18e
5 changed files with 63 additions and 0 deletions

View File

@@ -40,3 +40,20 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
}
super().__init__(tokenizer, parameters)
def train(
    self,
    files: Union[str, List[str]],
    vocab_size: int = 8000,
    show_progress: bool = True,
    special_tokens: Union[List[Union[str, AddedToken]], None] = None,
):
    """Train the model using the given files.

    Args:
        files: A single file name, or a list of file names. A lone string
            is wrapped into a one-element list before training.
        vocab_size: Forwarded to ``trainers.UnigramTrainer`` as ``vocab_size``.
        show_progress: Forwarded to ``trainers.UnigramTrainer`` as
            ``show_progress``.
        special_tokens: Forwarded to ``trainers.UnigramTrainer`` as
            ``special_tokens``. ``None`` (the default) is treated as an
            empty list.
    """
    # Use a None sentinel instead of a `[]` default: a list literal in the
    # signature is created once and shared by every call to this method.
    if special_tokens is None:
        special_tokens = []

    # Accept a bare string for convenience; the underlying train call is
    # always given a list.
    if isinstance(files, str):
        files = [files]

    trainer = trainers.UnigramTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        show_progress=show_progress,
    )
    self._tokenizer.train(trainer, files)

View File

@@ -3,3 +3,4 @@ from .. import trainers
# Module-level aliases for the trainer classes exposed on `trainers`, so each
# class can be referenced by name directly from this module.
Trainer = trainers.Trainer
BpeTrainer = trainers.BpeTrainer
WordPieceTrainer = trainers.WordPieceTrainer
UnigramTrainer = trainers.UnigramTrainer

View File

@@ -111,3 +111,32 @@ class WordPieceTrainer(Trainer):
Trainer
"""
pass
class UnigramTrainer(Trainer):
    """UnigramTrainer

    Capable of training a Unigram model
    """

    def __init__(
        self,
        vocab_size: int = 8000,
        show_progress: bool = True,
        special_tokens: List[Union[str, AddedToken]] = [],
    ) -> None:
        """Instantiate a new UnigramTrainer with the given options.

        Calling ``UnigramTrainer(...)`` yields the new trainer instance;
        ``__init__`` itself returns ``None``, as every ``__init__`` must.

        Args:
            vocab_size: unsigned int:
                The size of the final vocabulary, including all tokens and alphabet.

            show_progress: boolean:
                Whether to show progress bars while training.

            special_tokens: List[Union[str, AddedToken]]:
                A list of special tokens the model should know of.
        """
        # NOTE(review): the empty body suggests this is a signature/doc
        # declaration whose real implementation lives elsewhere — confirm.
        pass