diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 01ec59e6..a02527f3 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -259,7 +259,7 @@ impl AddedVocabulary {
                 ignored += 1;
                 continue;
             }
-            
+
             // If a token is already part of the vocabulary, we mark it as added
             let id = if let Some(id) = self.token_to_id(&token.content, model) {
                 id
@@ -267,14 +267,16 @@ impl AddedVocabulary {
                 let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
                 new_id
             };
-            
+
-            if self.added_tokens_map_r.values().any(| val| val == token){
+            if self.added_tokens_map_r.values().any(|val| val == token) {
                 // We only ignore if the AddedToken is already part of the added_tokens_map_r
-                ignored +=1;
-            }
-            else{
+                ignored += 1;
+            } else {
                 // Make sure we modify the previous entry
-                self.added_tokens_map.entry(token.content.clone()).and_modify(|old_id| *old_id = id).or_insert_with(||id);
+                self.added_tokens_map
+                    .entry(token.content.clone())
+                    .and_modify(|old_id| *old_id = id)
+                    .or_insert_with(|| id);
                 if !self.special_tokens_set.contains(&token.content) {
                     self.added_tokens.push(token.clone());
                 }
@@ -284,7 +286,6 @@
                     .and_modify(|t| *t = token.clone())
                     .or_insert_with(|| token.clone());
             }
-            
         }

         self.refresh_added_tokens(model, normalizer);
@@ -603,8 +604,11 @@ mod tests {
         assert_eq!(vocab.len(), 2);

         // Also adds tokens already covered by the model
-        let added_token= AddedToken::from("test", false);
-        assert_eq!(vocab.add_tokens(&[added_token.clone()], &model, normalizer),1);
+        let added_token = AddedToken::from("test", false);
+        assert_eq!(
+            vocab.add_tokens(&[added_token.clone()], &model, normalizer),
+            1
+        );
         assert_eq!(vocab.len(), 3);

         assert_eq!(vocab.get_vocab_r()[&0], added_token);
@@ -647,11 +651,14 @@ mod tests {
         );
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
-        
-        assert_eq!(*vocab.get_vocab_r(), HashMap::from([
-            (2,AddedToken::from("added_token_1", true)),
-            (3,AddedToken::from("added_token_2", true)),
-            (0,AddedToken::from("test", true)),
-        ]));
+        assert_eq!(
+            *vocab.get_vocab_r(),
+            HashMap::from([
+                (2, AddedToken::from("added_token_1", true)),
+                (3, AddedToken::from("added_token_2", true)),
+                (0, AddedToken::from("test", true)),
+            ])
+        );
         assert!(vocab.added_tokens_map.contains_key("test"));
         assert!(vocab.added_tokens_map_r.contains_key(&0));
     }