diff --git a/bindings/python/py_src/tokenizers/__init__.py b/bindings/python/py_src/tokenizers/__init__.py
index efd57429..dfbf0333 100644
--- a/bindings/python/py_src/tokenizers/__init__.py
+++ b/bindings/python/py_src/tokenizers/__init__.py
@@ -89,7 +89,6 @@ from .tokenizers import (
     pre_tokenizers,
     processors,
     trainers,
-    __version__,
 )
 from .implementations import (
     BertWordPieceTokenizer,
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 561f1adf..989106d3 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -66,6 +66,12 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
+    /// Specify whether this token is special, meaning whether it should be skipped when decoding
+    #[must_use]
+    pub fn special(mut self, special: bool) -> Self {
+        self.special = special;
+        self
+    }
 }
 impl Default for AddedToken {
     fn default() -> Self {
@@ -192,6 +198,11 @@ impl AddedVocabulary {
         &self.added_tokens_map
     }
 
+    /// Get the additional vocabulary with the AddedTokens
+    pub fn get_vocab_r(&self) -> &HashMap<u32, AddedToken> {
+        &self.added_tokens_map_r
+    }
+
     /// Get the id matching one of our token if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
         self.added_tokens_map
@@ -248,26 +259,32 @@ impl AddedVocabulary {
                 ignored += 1;
                 continue;
             }
-
+
+            // If a token is already part of the vocabulary, we mark it as added
             let id = if let Some(id) = self.token_to_id(&token.content, model) {
-                ignored += 1;
                 id
             } else {
                 let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
-
+                new_id
+            };
+
+            if self.added_tokens_map_r.values().any(|val| val == token) {
+                // We only ignore if the AddedToken is already part of the added_tokens_map_r
+                ignored += 1;
+            }
+            else {
+                // Make sure we modify the previous entry
+                self.added_tokens_map.entry(token.content.clone()).and_modify(|old_id| *old_id = id).or_insert_with(|| id);
                 if !self.special_tokens_set.contains(&token.content) {
                     self.added_tokens.push(token.clone());
                 }
+                // Update the current revert operation
+                self.added_tokens_map_r
+                    .entry(id)
+                    .and_modify(|t| *t = token.clone())
+                    .or_insert_with(|| token.clone());
+            }
 
-                new_id
-            };
-
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
         }
 
         self.refresh_added_tokens(model, normalizer);
@@ -585,12 +602,12 @@ mod tests {
         );
         assert_eq!(vocab.len(), 2);
 
-        // Does not add tokens already covered by the model
-        assert_eq!(
-            vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
-            0
-        );
-        assert_eq!(vocab.len(), 2);
+        // Also adds tokens already covered by the model
+        let added_token = AddedToken::from("test", false);
+        assert_eq!(vocab.add_tokens(&[added_token.clone()], &model, normalizer), 1);
+        assert_eq!(vocab.len(), 3);
+
+        assert_eq!(vocab.get_vocab_r()[&1], added_token);
     }
 
     #[test]
@@ -626,10 +643,16 @@ mod tests {
         // Can add tokens already covered by the model
         assert_eq!(
             vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
-            0
+            1
         );
-        assert_eq!(vocab.len(), 2); // Did not add a new token, since it exist in the original model
+        assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
+        assert_eq!(*vocab.get_vocab_r(), HashMap::from([
+            (2, AddedToken::from("added_token_1", true)),
+            (3, AddedToken::from("added_token_2", true)),
+            (4, AddedToken::from("test", true)),
+        ]));
+        assert!(!vocab.added_tokens_map.contains_key("test"));
     }
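
The diff touches two user-facing pieces of the Rust API: AddedToken gains a special(bool) builder method, and add_tokens now counts a token as added even when the underlying model already contains it, tracking it in the added vocabulary (exposed via the new get_vocab_r()). Below is a minimal usage sketch, not part of the patch: it assumes the builder method introduced above, and BPE::default() stands in for any model, so the printed count is purely illustrative.

use tokenizers::models::bpe::BPE;
use tokenizers::{AddedToken, Tokenizer};

fn main() {
    // Stand-in model; any model convertible into a ModelWrapper works here.
    let mut tokenizer = Tokenizer::new(BPE::default());

    // New builder method from this patch: mark the token as special so it is
    // skipped when decoding, instead of setting the flag only through `from`.
    let mask = AddedToken::from("[MASK]", false).special(true);

    // With this patch, the returned count also includes tokens the model
    // already knows about, since they are now recorded in the added vocabulary.
    let added = tokenizer.add_special_tokens(&[mask]);
    println!("tokens added: {}", added);
}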