diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 1b7dd314..6053ee72 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -255,36 +255,36 @@ impl AddedVocabulary {
         // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
         for token in tokens {
-            if token.content.is_empty() {
+            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
+            {
                 ignored += 1;
                 continue;
             }
-            // If a token is already part of the vocabulary, we mark it as added
-            let id = if let Some(id) = self.token_to_id(&token.content, model) {
-                id
-            } else {
-                let new_id = (model.get_vocab_size() + cmp::max(self.added_tokens_map_r.keys(), 0)) as u32;
+            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
                 new_id
-            };
-
-            if self.added_tokens_map_r.values().any(|val| val == token) {
-                // We only ignore if the AddedToken is already part of the added_tokens_map_r
-                ignored += 1;
             } else {
-                // Make sure we modify the previous entry
                 self.added_tokens_map
-                    .entry(token.content.clone())
-                    .and_modify(|old_id| *old_id = id)
-                    .or_insert_with(|| id);
-                if !self.special_tokens_set.contains(&token.content) {
-                    self.added_tokens.push(token.clone());
-                }
-                // Update the current revert operation
-                self.added_tokens_map_r
-                    .entry(id)
-                    .and_modify(|t| *t = token.clone())
-                    .or_insert_with(|| token.clone());
+                    .values()
+                    .cloned()
+                    .max()
+                    .map_or(model.get_vocab_size() as u32, |max| max.clone() + 1)
+            };
+            // Make sure we modify the previous entry
+            self.added_tokens_map
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = new_id)
+                .or_insert_with(|| new_id);
+            // Update the current revert operation
+            self.added_tokens_map_r
+                .entry(new_id)
+                .and_modify(|t| *t = token.clone())
+                .or_insert_with(|| token.clone());
+            // Make sure to remove previous entry (if the token gets a new id)
+
+            // Finally add the token to the classic set if it is not special
+            if !self.special_tokens_set.contains(&token.content) {
+                self.added_tokens.push(token.clone());
             }
         }
@@ -654,13 +654,24 @@ mod tests {
         assert_eq!(
             *vocab.get_vocab_r(),
             HashMap::from([
+                (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),
                 (3, AddedToken::from("added_token_2", true)),
-                (0, AddedToken::from("test", true)),
             ])
        );
         assert!(vocab.added_tokens_map.contains_key("test"));
         assert!(vocab.added_tokens_map_r.contains_key(&0));
+
+        vocab.add_tokens(
+            &[
+                AddedToken::from("tost", true),
+                AddedToken::from("another_two", true),
+            ],
+            &model,
+            normalizer,
+        );
+        assert_eq!(vocab.len(), 5); // New token was added
+        assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab
     }
 
     #[test]
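
The behavioral core of this change is the fallback id computation: a token the model does not know no longer gets `model.get_vocab_size()` plus an offset derived from the reverse map, but rather `max(existing added ids) + 1`, defaulting to `model.get_vocab_size()` when nothing has been added yet. Below is a minimal, self-contained sketch of that rule, assuming a plain `HashMap<String, u32>` in place of `added_tokens_map`; the helper name `next_added_id` and the `vocab_size` parameter are illustrative, not part of the tokenizers API:

```rust
use std::collections::HashMap;

/// Illustrative helper (not part of the tokenizers crate): picks the id for a
/// token the model does not know. It mirrors the `map_or` logic in the diff:
/// continue from the highest id already handed out, or start at the end of
/// the model's vocabulary when nothing has been added yet.
fn next_added_id(added_tokens_map: &HashMap<String, u32>, vocab_size: usize) -> u32 {
    added_tokens_map
        .values()
        .copied()
        .max()
        .map_or(vocab_size as u32, |max| max + 1)
}

fn main() {
    let mut added: HashMap<String, u32> = HashMap::new();
    let vocab_size = 10; // pretend the model owns ids 0..10

    // Nothing added yet: the first added token starts right after the model ids.
    assert_eq!(next_added_id(&added, vocab_size), 10);
    added.insert("added_token_1".into(), 10);

    // Ids always continue from the current maximum, so they stay unique even
    // if the map is sparse or an added token carries an unusually large id.
    added.insert("special".into(), 25);
    assert_eq!(next_added_id(&added, vocab_size), 26);
}
```

This is what the new test at the bottom of the diff exercises: with added ids 0, 2, and 3 already present, `another_two` receives id 4 (the current maximum plus one) even though the vocabulary then holds 5 entries, hence the comment that the index is not the length of the vocab.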