From d1566a9ecc4e6f3d67ca5bcb71155f704a694fbc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Sep 2023 20:48:36 +0000 Subject: [PATCH] update, // AddedTokens can be updated if value changed --- tokenizers/src/tokenizer/added_vocabulary.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 6053ee72..92235d55 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -85,8 +85,7 @@ impl Default for AddedToken { } } } -// We only want to hash on the content. AddedToken cannot be added multiple times with different -// options +// AddedTokens can be updated if value changed impl std::hash::Hash for AddedToken { fn hash(&self, state: &mut H) { self.content.hash(state); @@ -94,7 +93,7 @@ impl std::hash::Hash for AddedToken { } impl std::cmp::PartialEq for AddedToken { fn eq(&self, other: &Self) -> bool { - self.content == other.content + self.content == other.content && self.special == other.special && self.lstrip == other.lstrip && self.rstrip == other.rstrip && self.normalized == other.normalized && self.single_word == other.single_word } } impl std::cmp::Eq for AddedToken {} @@ -665,13 +664,18 @@ mod tests { vocab.add_tokens( &[ AddedToken::from("tost", true), - AddedToken::from("another_two", true), + AddedToken::from("another_two", false), ], &model, normalizer, ); assert_eq!(vocab.len(), 5); // New token was added assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab + + // Let's add an already added token again + assert_eq!(vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer), 1); + assert_eq!(vocab.len(), 5); // Token was already there + assert_eq!(vocab.get_vocab()["another_two"], 4); // Token idx not changed } #[test]