fix and update tests

Arthur Zucker
2023-09-01 20:40:06 +00:00
parent 2b72017e17
commit 399c6fe852

@@ -255,36 +255,36 @@ impl AddedVocabulary
         // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
         for token in tokens {
-            if token.content.is_empty() {
+            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
+            {
                 ignored += 1;
                 continue;
             }
             // If a token is already part of the vocabulary, we mark it as added
-            let id = if let Some(id) = self.token_to_id(&token.content, model) {
-                id
-            } else {
-                let new_id = (model.get_vocab_size() + cmp::max(self.added_tokens_map_r.keys(), 0)) as u32;
+            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
                 new_id
-            };
-            if self.added_tokens_map_r.values().any(|val| val == token) {
-                // We only ignore if the AddedToken is already part of the added_tokens_map_r
-                ignored += 1;
             } else {
-                // Make sure we modify the previous entry
-                self.added_tokens_map
-                    .entry(token.content.clone())
-                    .and_modify(|old_id| *old_id = id)
-                    .or_insert_with(|| id);
-                if !self.special_tokens_set.contains(&token.content) {
-                    self.added_tokens.push(token.clone());
-                }
-                // Update the current revert operation
-                self.added_tokens_map_r
-                    .entry(id)
-                    .and_modify(|t| *t = token.clone())
-                    .or_insert_with(|| token.clone());
+                self.added_tokens_map
+                    .values()
+                    .cloned()
+                    .max()
+                    .map_or(model.get_vocab_size() as u32, |max| max.clone() + 1)
+            };
+            // Make sure we modify the previous entry
+            self.added_tokens_map
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = new_id)
+                .or_insert_with(|| new_id);
+            // Update the current revert operation
+            self.added_tokens_map_r
+                .entry(new_id)
+                .and_modify(|t| *t = token.clone())
+                .or_insert_with(|| token.clone());
+            // Make sure to remove previous entry (if the token gets a new id)
+            // Finally add the token to the classic set if special
+            if !self.special_tokens_set.contains(&token.content) {
+                self.added_tokens.push(token.clone());
             }
         }
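
The core behavioral change in this hunk is how a fresh id is chosen: duplicates are now skipped up front via an early continue, and a genuinely new token gets max(existing added ids) + 1, falling back to the model's vocab size when nothing has been added yet. Below is a minimal standalone sketch of that rule, using plain HashMaps in place of the real AddedVocabulary internals (assign_id, model_vocab, and added_tokens_map are illustrative names here, not the library's API):

use std::collections::HashMap;

// Minimal sketch of the id-assignment rule (illustrative, not the real
// `AddedVocabulary` API). A known token keeps its id; a new token gets
// max(existing added ids) + 1, or the model vocab size if nothing was
// added yet.
fn assign_id(
    content: &str,
    model_vocab: &HashMap<String, u32>,      // model tokens: ids 0..n
    added_tokens_map: &HashMap<String, u32>, // added tokens: content -> id
) -> u32 {
    if let Some(&id) = added_tokens_map
        .get(content)
        .or_else(|| model_vocab.get(content))
    {
        // Token is already known: keep its id stable instead of re-numbering.
        id
    } else {
        added_tokens_map
            .values()
            .copied()
            .max()
            .map_or(model_vocab.len() as u32, |max| max + 1)
    }
}

fn main() {
    let model_vocab = HashMap::from([("test".to_string(), 0u32)]);
    let mut added: HashMap<String, u32> = HashMap::new();

    // No added token yet and "new_tok" is unknown: fall back to vocab size.
    assert_eq!(assign_id("new_tok", &model_vocab, &added), 1);

    // "test" already exists in the model, so it reuses id 0.
    let id = assign_id("test", &model_vocab, &added);
    added.insert("test".to_string(), id);
    assert_eq!(id, 0);

    // The next fresh token continues from the highest added id, not from len().
    assert_eq!(assign_id("new_tok", &model_vocab, &added), 1);
}

Because a known token keeps its id rather than taking a fresh slot, re-adding it is a no-op, which is what the early dedup check and the `ignored` counter in the diff rely on.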
@@ -654,13 +654,24 @@ mod tests {
         assert_eq!(
             *vocab.get_vocab_r(),
             HashMap::from([
+                (0, AddedToken::from("test", true)),
                 (2, AddedToken::from("added_token_1", true)),
                 (3, AddedToken::from("added_token_2", true)),
-                (0, AddedToken::from("test", true)),
             ])
         );
         assert!(vocab.added_tokens_map.contains_key("test"));
         assert!(vocab.added_tokens_map_r.contains_key(&0));
+        vocab.add_tokens(
+            &[
+                AddedToken::from("tost", true),
+                AddedToken::from("another_two", true),
+            ],
+            &model,
+            normalizer,
+        );
+        assert_eq!(vocab.len(), 5); // New token was added
+        assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab
     }
     #[test]
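
The last assertion is the one worth noting: after adding two more tokens the vocabulary holds five entries, yet "another_two" maps to id 4 rather than to vocab.len(). Because "test" reuses the model's id 0 instead of consuming a fresh slot, added ids are not contiguous, so the token count cannot serve as a "next id". A small illustrative sketch of that invariant (a plain HashMap standing in for the test fixture, with values chosen to mirror the assertions above):

use std::collections::HashMap;

// Illustrative only (plain maps, not the real test fixture): token count
// and id arithmetic diverge once an added token reuses a low model id.
fn main() {
    // content -> id, mirroring `added_tokens_map` after the first assertions:
    // "test" kept the model's id 0, the other two got fresh ids 2 and 3.
    let added: HashMap<&str, u32> =
        HashMap::from([("test", 0), ("added_token_1", 2), ("added_token_2", 3)]);

    // Three tokens are known, but the next fresh id is max + 1 = 4,
    // not len() = 3: id 1 belongs to the model and stays untouched.
    assert_eq!(added.len(), 3);
    let next_id = added.values().copied().max().map_or(0, |max| max + 1);
    assert_eq!(next_id, 4);
}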