update, // AddedTokens can be updated if value changed

This commit is contained in:
Arthur Zucker
2023-09-01 20:48:36 +00:00
parent 399c6fe852
commit d1566a9ecc

View File

@ -85,8 +85,7 @@ impl Default for AddedToken {
} }
} }
} }
// We only want to hash on the content. AddedToken cannot be added multiple times with different // AddedTokens can be updated if value changed
// options
impl std::hash::Hash for AddedToken { impl std::hash::Hash for AddedToken {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.content.hash(state); self.content.hash(state);
@ -94,7 +93,7 @@ impl std::hash::Hash for AddedToken {
} }
impl std::cmp::PartialEq for AddedToken { impl std::cmp::PartialEq for AddedToken {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.content == other.content self.content == other.content && self.special == other.special && self.lstrip == other.lstrip && self.rstrip == other.rstrip && self.normalized == other.normalized && self.single_word == other.single_word
} }
} }
impl std::cmp::Eq for AddedToken {} impl std::cmp::Eq for AddedToken {}
@ -665,13 +664,18 @@ mod tests {
vocab.add_tokens( vocab.add_tokens(
&[ &[
AddedToken::from("tost", true), AddedToken::from("tost", true),
AddedToken::from("another_two", true), AddedToken::from("another_two", false),
], ],
&model, &model,
normalizer, normalizer,
); );
assert_eq!(vocab.len(), 5); // New token was added assert_eq!(vocab.len(), 5); // New token was added
assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab
// Let's add an already added token again
assert_eq!(vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer), 1);
assert_eq!(vocab.len(), 5); // Token was already there
assert_eq!(vocab.get_vocab()["another_two"], 4); // Token idx not changed
} }
#[test] #[test]