Python - Make the trainer optional on Tokenizer.train

Author:    Anthony MOI
Date:      2020-10-07 21:25:32 -04:00
Committer: Anthony MOI
parent c230183cf6
commit 224862fe0c
7 changed files with 15 additions and 12 deletions


@@ -8,7 +8,7 @@ use pyo3::types::*;
 use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
-    PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
+    Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
 use tokenizers as tk;
@@ -1039,10 +1039,13 @@ impl PyTokenizer {
         Ok(self.tokenizer.add_special_tokens(&tokens))
     }
 
-    fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
-        let gil = Python::acquire_gil();
-        gil.python()
-            .allow_threads(|| ToPyResult(self.tokenizer.train_and_replace(trainer, files)).into())
+    #[args(trainer = "None")]
+    fn train(&mut self, files: Vec<String>, trainer: Option<&PyTrainer>) -> PyResult<()> {
+        let trainer =
+            trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
+        Python::with_gil(|py| {
+            py.allow_threads(|| ToPyResult(self.tokenizer.train_and_replace(&trainer, files)).into())
+        })
     }
 
     /// Apply all the post-processing steps to the given encodings.
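From the Python side, this makes the trainer argument to Tokenizer.train optional: when it is omitted, the bound model supplies its default trainer via get_trainer(), and an explicit trainer still takes precedence. A minimal usage sketch (the corpus path and trainer settings are illustrative, not part of this commit):

from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

tokenizer = Tokenizer(BPE())

# With this change the trainer can be omitted; the BPE model's
# default trainer is used instead.
tokenizer.train(["corpus.txt"])  # hypothetical training file

# Passing a trainer explicitly still works and overrides the default.
trainer = BpeTrainer(vocab_size=8000, special_tokens=["[UNK]"])
tokenizer.train(["corpus.txt"], trainer)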