diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index e00edfd0..167f1680 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -667,10 +667,8 @@ impl PyTokenizer {
     }
 
     fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
-        self.tokenizer = self
-            .tokenizer
-            .clone()
-            .train(trainer, files)
+        self.tokenizer
+            .train_and_replace(trainer, files)
             .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
         Ok(())
     }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 03dfcf49..54b5a044 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -1004,7 +1004,7 @@ where
         Ok(words)
     }
 
-    /// Train a model and replace our current Model, using the given Trainer
+    /// Train a model and return a new Tokenizer, using the given Trainer
     pub fn train<T>(
         self,
         trainer: &T,
@@ -1032,6 +1032,20 @@ where
 
         Ok(new_tok)
     }
+
+    /// Train a model and replace our current Model, using the given Trainer
+    pub fn train_and_replace<T>(&mut self, trainer: &T, files: Vec<String>) -> Result<()>
+    where
+        T: Trainer<Model = M> + Sync,
+    {
+        let words = self.word_count(trainer, files)?;
+
+        let (model, special_tokens) = trainer.train(words)?;
+        self.model = model;
+        self.add_special_tokens(&special_tokens);
+
+        Ok(())
+    }
 }
 
 impl std::str::FromStr for TokenizerImpl