Add in-place train.

This commit is contained in:
Sebastian Pütz
2020-08-04 20:32:10 +02:00
committed by Anthony MOI
parent ac8af63f70
commit 10a39ba6b4
2 changed files with 17 additions and 5 deletions

View File

@ -667,10 +667,8 @@ impl PyTokenizer {
} }
fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> { fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
self.tokenizer = self self.tokenizer
.tokenizer .train_and_replace(trainer, files)
.clone()
.train(trainer, files)
.map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?; .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
Ok(()) Ok(())
} }

View File

@ -1004,7 +1004,7 @@ where
Ok(words) Ok(words)
} }
/// Train a model and replace our current Model, using the given Trainer /// Train a model and return a new Tokenizer, using the given Trainer
pub fn train<T, TM>( pub fn train<T, TM>(
self, self,
trainer: &T, trainer: &T,
@ -1032,6 +1032,20 @@ where
Ok(new_tok) Ok(new_tok)
} }
/// Train a model and replace our current Model, using the given Trainer
pub fn train_and_replace<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<()>
where
T: Trainer<Model = M> + Sync,
{
let words = self.word_count(trainer, files)?;
let (model, special_tokens) = trainer.train(words)?;
self.model = model;
self.add_special_tokens(&special_tokens);
Ok(())
}
} }
impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D> impl<M, N, PT, PP, D> std::str::FromStr for TokenizerImpl<M, N, PT, PP, D>