Arthur Zucker
2023-09-01 18:41:36 +00:00
parent 8e522a38d9
commit 345b4eba96
2 changed files with 43 additions and 21 deletions

@@ -89,7 +89,6 @@ from .tokenizers import (
     pre_tokenizers,
     processors,
     trainers,
-    __version__,
 )
 from .implementations import (
     BertWordPieceTokenizer,

@@ -66,6 +66,12 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
+    /// Specify whether this token is special, meaning it should be skipped when decoding
+    #[must_use]
+    pub fn special(mut self, special: bool) -> Self {
+        self.special = special;
+        self
+    }
 }

 impl Default for AddedToken {
     fn default() -> Self {
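
The new `special` setter completes the builder-style API on `AddedToken` (cf. the `normalized` setter just above it). A minimal usage sketch, assuming the `AddedToken::from(content, special)` constructor that appears in the tests further down:

    // Mark a token as special after construction; "<mask>" is a hypothetical token.
    let mask = AddedToken::from("<mask>", false).special(true);
    assert!(mask.special);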
@@ -192,6 +198,11 @@ impl AddedVocabulary {
         &self.added_tokens_map
     }

+    /// Get the additional vocabulary, mapping ids back to AddedTokens
+    pub fn get_vocab_r(&self) -> &HashMap<u32, AddedToken> {
+        &self.added_tokens_map_r
+    }
+
     /// Get the id matching one of our token if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
         self.added_tokens_map
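
`get_vocab_r` exposes the id-to-token side of the added vocabulary, which the next hunk starts keeping in sync on every add. A lookup sketch, assuming a populated `AddedVocabulary` named `vocab` (setup omitted):

    // Reverse lookup: find the AddedToken registered under a given id.
    if let Some(token) = vocab.get_vocab_r().get(&4) {
        println!("id 4 -> {}", token.content);
    }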
@@ -248,26 +259,32 @@ impl AddedVocabulary {
                 ignored += 1;
                 continue;
             }

             // If a token is already part of the vocabulary, we mark it as added
             let id = if let Some(id) = self.token_to_id(&token.content, model) {
-                ignored += 1;
                 id
             } else {
                 let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
+                new_id
+            };
+            if self.added_tokens_map_r.values().any(|val| val == token) {
+                // We only ignore if the AddedToken is already part of the added_tokens_map_r
+                ignored += 1;
+            } else {
+                // Make sure we modify the previous entry
+                self.added_tokens_map
+                    .entry(token.content.clone())
+                    .and_modify(|old_id| *old_id = id)
+                    .or_insert_with(|| id);
                 if !self.special_tokens_set.contains(&token.content) {
                     self.added_tokens.push(token.clone());
                 }
+                // Update the current revert operation
+                self.added_tokens_map_r
+                    .entry(id)
+                    .and_modify(|t| *t = token.clone())
+                    .or_insert_with(|| token.clone());
+            }
-                new_id
-            };
-
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
         }

         self.refresh_added_tokens(model, normalizer);
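
The change above inverts the old early-ignore: a token whose content is already covered (by the model or a previous add) keeps its existing id, only an exact duplicate already present in `added_tokens_map_r` is counted as ignored, and otherwise both maps are updated in place. A self-contained sketch of that upsert rule using plain `std` maps (all names here are illustrative stand-ins, not the crate's API):

    use std::collections::HashMap;

    /// Stand-in for the forward/reverse bookkeeping done in add_tokens.
    /// Returns false when the entry is an exact duplicate and is ignored.
    fn upsert_token(
        forward: &mut HashMap<String, u32>,  // content -> id (added_tokens_map)
        reverse: &mut HashMap<u32, String>,  // id -> content (added_tokens_map_r)
        content: &str,
        id: u32,
    ) -> bool {
        // Only an entry already present in the reverse map is ignored.
        if reverse.values().any(|existing| existing.as_str() == content) {
            return false;
        }
        // Overwrite any stale id recorded for this content...
        forward
            .entry(content.to_string())
            .and_modify(|old_id| *old_id = id)
            .or_insert(id);
        // ...and keep the reverse map consistent with it.
        reverse
            .entry(id)
            .and_modify(|t| *t = content.to_string())
            .or_insert_with(|| content.to_string());
        true
    }

Calling `upsert_token` twice with the same pair returns false on the second call, mirroring how `ignored` is incremented for duplicates in the hunk.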
@@ -585,12 +602,12 @@ mod tests {
         );
         assert_eq!(vocab.len(), 2);

-        // Does not add tokens already covered by the model
-        assert_eq!(
-            vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
-            0
-        );
-        assert_eq!(vocab.len(), 2);
+        // Also adds tokens already covered by the model
+        let added_token = AddedToken::from("test", false);
+        assert_eq!(vocab.add_tokens(&[added_token.clone()], &model, normalizer), 1);
+        assert_eq!(vocab.len(), 3);
+        assert_eq!(vocab.get_vocab_r()[&1], added_token);
     }

     #[test]
@@ -626,10 +643,16 @@ mod tests {
         // Can add tokens already covered by the model
         assert_eq!(
             vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
-            0
+            1
         );
-        assert_eq!(vocab.len(), 2); // Did not add a new token, since it exist in the original model
+        assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
+        assert_eq!(*vocab.get_vocab_r(), HashMap::from([
+            (2, AddedToken::from("added_token_1", true)),
+            (3, AddedToken::from("added_token_2", true)),
+            (4, AddedToken::from("test", true)),
+        ]));
+        assert!(!vocab.added_tokens_map.contains_key("test"));
     }
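
Taken together, the two test updates pin down the new contract: tokens already covered by the model are now reported as added rather than silently dropped, `len()` grows accordingly, and the token becomes reachable through `get_vocab_r`. What a caller observes, sketched against the test fixtures (`model`, `normalizer`) that are not shown in this hunk:

    // Before this commit this returned 0 and "test" was silently skipped.
    let num_added = vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer);
    assert_eq!(num_added, 1);
    assert_eq!(vocab.get_vocab_r()[&4], AddedToken::from("test", true));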