diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 75c1a175..3b9b6cac 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -233,6 +233,10 @@ impl Tokenizer {
         Ok(self.tokenizer.add_tokens(&tokens))
     }
 
+    fn add_special_tokens(&mut self, tokens: Vec<&str>) -> PyResult<usize> {
+        Ok(self.tokenizer.add_special_tokens(&tokens))
+    }
+
     fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> PyResult<()> {
         trainer.trainer.execute(|trainer| {
             if let Err(e) = self.tokenizer.train(trainer, files) {
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 631d6293..9603eebd 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -140,6 +140,7 @@ pub struct Tokenizer {
     added_tokens: HashMap<AddedToken, u32>,
     added_tokens_r: HashMap<u32, AddedToken>,
     split_re: Option<regex::Regex>,
+    special_tokens: HashMap<String, u32>,
 
     // General processing parameters
     trunc: Option<TruncationParams>,
@@ -159,6 +160,7 @@ impl Tokenizer {
             added_tokens: HashMap::new(),
             added_tokens_r: HashMap::new(),
             split_re: None,
+            special_tokens: HashMap::new(),
 
             trunc: None,
             padding: None,
@@ -483,6 +485,24 @@ impl Tokenizer {
         Ok(final_encoding)
     }
 
+    /// Register the given tokens as special tokens. This is especially useful for removing
+    /// these special tokens while decoding
+    pub fn add_special_tokens(&mut self, tokens: &[&str]) -> usize {
+        let added_tokens = tokens
+            .iter()
+            .map(|t| AddedToken::from((*t).to_owned()))
+            .collect::<Vec<_>>();
+
+        let added = self.add_tokens(&added_tokens);
+        for token in tokens {
+            if let Some(id) = self.token_to_id(token) {
+                self.special_tokens.entry((*token).to_owned()).or_insert(id);
+            }
+        }
+
+        added
+    }
+
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
         let mut ignored = 0;
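
A rough usage sketch of the new Rust API (not part of the diff above): it assumes a `tokenizers::tokenizer::Tokenizer` has already been constructed with a trained model, and the `[CLS]`/`[SEP]` token strings are placeholders for illustration only.

    // Sketch only: `tokenizer` is an already-built tokenizers::tokenizer::Tokenizer
    // with a trained model behind it; the token strings are hypothetical.
    fn register_specials(tokenizer: &mut tokenizers::tokenizer::Tokenizer) -> usize {
        // Adds the tokens to the added vocabulary if they are not known yet, and
        // records their ids in `special_tokens` so they can be stripped when decoding.
        tokenizer.add_special_tokens(&["[CLS]", "[SEP]"])
    }

The returned `usize` is whatever `add_tokens` reports for the given tokens, i.e. how many of them actually had to be added to the vocabulary.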