From 1a90cc96e50af1b11ae3b7b2425a511b2b86fb93 Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Mon, 16 Dec 2019 18:45:26 -0500
Subject: [PATCH] Python - Can add tokens

---
 bindings/python/src/tokenizer.rs | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index b7f38c91..bb89768c 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -122,6 +122,31 @@ impl Tokenizer {
         self.tokenizer.id_to_token(id)
     }
 
+    fn add_tokens(&mut self, tokens: &PyList) -> PyResult<usize> {
+        let tokens = tokens
+            .into_iter()
+            .map(|token| {
+                if let Ok(content) = token.extract::<String>() {
+                    Ok(tk::tokenizer::AddedToken {
+                        content,
+                        ..Default::default()
+                    })
+                } else if let Ok((content, single_word)) = token.extract::<(String, bool)>() {
+                    Ok(tk::tokenizer::AddedToken {
+                        content,
+                        single_word,
+                    })
+                } else {
+                    Err(exceptions::Exception::py_err(
+                        "Input must be a list[str] or list[(str, bool)]",
+                    ))
+                }
+            })
+            .collect::<PyResult<Vec<_>>>()?;
+
+        Ok(self.tokenizer.add_tokens(&tokens))
+    }
+
     fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
         trainer.trainer.execute(|trainer| {
            if let Err(e) = self.tokenizer.train(trainer, files) {