commit db319492f7 (parent 2dca476810)
Author: Arthur Zucker
Date:   2023-09-01 18:57:39 +00:00


@@ -259,7 +259,7 @@ impl AddedVocabulary {
                 ignored += 1;
                 continue;
             }
             // If a token is already part of the vocabulary, we mark it as added
             let id = if let Some(id) = self.token_to_id(&token.content, model) {
                 id
@@ -267,14 +267,16 @@ impl AddedVocabulary {
                 let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
                 new_id
             };
-            if self.added_tokens_map_r.values().any(| val| val == token){
+            if self.added_tokens_map_r.values().any(|val| val == token) {
                 // We only ignore if the AddedToken is already part of the added_tokens_map_r
-                ignored +=1;
-            }
-            else{
+                ignored += 1;
+            } else {
                 // Make sure we modify the previous entry
-                self.added_tokens_map.entry(token.content.clone()).and_modify(|old_id| *old_id = id).or_insert_with(||id);
+                self.added_tokens_map
+                    .entry(token.content.clone())
+                    .and_modify(|old_id| *old_id = id)
+                    .or_insert_with(|| id);
                 if !self.special_tokens_set.contains(&token.content) {
                     self.added_tokens.push(token.clone());
                 }
@@ -284,7 +286,6 @@ impl AddedVocabulary {
                     .and_modify(|t| *t = token.clone())
                     .or_insert_with(|| token.clone());
             }
         }
         self.refresh_added_tokens(model, normalizer);
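
The reflowed block above leans on `HashMap`'s entry API: overwrite the id in place when the content is already mapped, otherwise insert the freshly computed one. A minimal, self-contained sketch of that idiom, using a plain `HashMap<String, u32>` as an illustrative stand-in for `added_tokens_map` (not the crate's actual types):

use std::collections::HashMap;

fn main() {
    // Stand-in for `added_tokens_map`: token content -> id.
    let mut added_tokens_map: HashMap<String, u32> = HashMap::new();
    added_tokens_map.insert("test".to_string(), 0);

    let id = 3;
    // Same shape as the diff: modify an existing entry's id,
    // or insert `id` if the content was never added.
    added_tokens_map
        .entry("test".to_string())
        .and_modify(|old_id| *old_id = id)
        .or_insert_with(|| id);

    assert_eq!(added_tokens_map["test"], 3);
}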
@@ -603,8 +604,11 @@ mod tests {
         assert_eq!(vocab.len(), 2);
         // Also adds tokens already covered by the model
-        let added_token= AddedToken::from("test", false);
-        assert_eq!(vocab.add_tokens(&[added_token.clone()], &model, normalizer),1);
+        let added_token = AddedToken::from("test", false);
+        assert_eq!(
+            vocab.add_tokens(&[added_token.clone()], &model, normalizer),
+            1
+        );
         assert_eq!(vocab.len(), 3);
         assert_eq!(vocab.get_vocab_r()[&0], added_token);
@@ -647,11 +651,14 @@ mod tests {
         );
         assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
-        assert_eq!(*vocab.get_vocab_r(), HashMap::from([
-            (2,AddedToken::from("added_token_1", true)),
-            (3,AddedToken::from("added_token_2", true)),
-            (0,AddedToken::from("test", true)),
-        ]));
+        assert_eq!(
+            *vocab.get_vocab_r(),
+            HashMap::from([
+                (2, AddedToken::from("added_token_1", true)),
+                (3, AddedToken::from("added_token_2", true)),
+                (0, AddedToken::from("test", true)),
+            ])
+        );
         assert!(vocab.added_tokens_map.contains_key("test"));
         assert!(vocab.added_tokens_map_r.contains_key(&0));
     }
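
For context on what the reformatted `if` in the first hunk is guarding: a token only counts as ignored when an identical `AddedToken` already sits in the reverse map; a token whose content exists but whose flags differ falls through to the `else` branch and has its entry updated. A rough sketch of that check, using a hypothetical, stripped-down `AddedToken` rather than the crate's real struct:

use std::collections::HashMap;

// Hypothetical, simplified stand-in for the crate's AddedToken.
#[derive(Clone, Debug, PartialEq)]
struct AddedToken {
    content: String,
    special: bool,
}

fn main() {
    // Stand-in for `added_tokens_map_r`: id -> AddedToken.
    let mut added_tokens_map_r: HashMap<u32, AddedToken> = HashMap::new();
    added_tokens_map_r.insert(
        0,
        AddedToken { content: "test".into(), special: false },
    );

    let token = AddedToken { content: "test".into(), special: true };
    let mut ignored = 0;

    // Same shape as the diff: only an exact duplicate is ignored.
    if added_tokens_map_r.values().any(|val| *val == token) {
        ignored += 1;
    }

    // `special` differs, so this token is not treated as a duplicate.
    assert_eq!(ignored, 0);
}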