mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Commit: updates
@@ -89,7 +89,6 @@ from .tokenizers import (
     pre_tokenizers,
     processors,
     trainers,
-    __version__,
 )
 from .implementations import (
     BertWordPieceTokenizer,
@@ -66,6 +66,12 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
+    /// Specify whether this token is special, meaning if it should be skipped when decoding
+    #[must_use]
+    pub fn special(mut self, special: bool) -> Self {
+        self.special = special;
+        self
+    }
 }
 impl Default for AddedToken {
     fn default() -> Self {
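For context (not part of the commit), here is a minimal usage sketch of the builder this hunk completes, assuming the `tokenizers` crate as a dependency: `AddedToken::from(content, special)` and `normalized(...)` already exist in the crate, while `special(...)` is the method added above.

use tokenizers::AddedToken;

fn main() {
    // `from(content, special)` constructs the token; `special(true)` is the new
    // builder added above and marks the token to be skipped when decoding, while
    // `normalized(false)` keeps matching it against the raw, un-normalized input.
    let token = AddedToken::from("<custom>", false).special(true).normalized(false);
    println!("{token:?}");
}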
@@ -192,6 +198,11 @@ impl AddedVocabulary {
         &self.added_tokens_map
     }
+
+    /// Get the additional vocabulary with the AddedTokens
+    pub fn get_vocab_r(&self) -> &HashMap<u32, AddedToken> {
+        &self.added_tokens_map_r
+    }
 
     /// Get the id matching one of our token if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
         self.added_tokens_map
@@ -248,26 +259,32 @@ impl AddedVocabulary {
                 ignored += 1;
                 continue;
             }
             // If a token is already part of the vocabulary, we mark it as added
             let id = if let Some(id) = self.token_to_id(&token.content, model) {
                 ignored += 1;
                 id
             } else {
                 let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
                 self.added_tokens_map.insert(token.content.clone(), new_id);
 
+                new_id
+            };
+            if self.added_tokens_map_r.values().any(|val| val == token) {
+                // We only ignore if the AddedToken is already part of the added_tokens_map_r
+                ignored += 1;
+            }
+            else {
+                // Make sure we modify the previous entry
+                self.added_tokens_map.entry(token.content.clone()).and_modify(|old_id| *old_id = id).or_insert_with(|| id);
                 if !self.special_tokens_set.contains(&token.content) {
                     self.added_tokens.push(token.clone());
                 }
-
-                new_id
-            };
-
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
+                // Update the current revert operation
+                self.added_tokens_map_r
+                    .entry(id)
+                    .and_modify(|t| *t = token.clone())
+                    .or_insert_with(|| token.clone());
+            }
         }
 
         self.refresh_added_tokens(model, normalizer);
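To make the new bookkeeping above easier to follow, here is a small self-contained sketch (not part of the commit: plain `HashMap`s and `String` tokens stand in for the real model and `AddedVocabulary` fields, and `add_token` is a hypothetical helper name) showing the id assignment and how the forward (content -> id) and reverse (id -> token) maps are kept in sync with the entry API.

use std::collections::HashMap;

// Simplified stand-in for the logic in `add_tokens` above: the model vocabulary
// is a plain map and an added token is just its content string.
fn add_token(
    model_vocab: &HashMap<String, u32>,
    added_tokens_map: &mut HashMap<String, u32>,
    added_tokens_map_r: &mut HashMap<u32, String>,
    token: &str,
) -> u32 {
    // Reuse the model's id when it already knows the token, otherwise allocate
    // the next id after the model vocabulary and the tokens added so far.
    let id = if let Some(&id) = model_vocab.get(token) {
        id
    } else {
        let new_id = (model_vocab.len() + added_tokens_map.len()) as u32;
        added_tokens_map.insert(token.to_string(), new_id);
        new_id
    };

    if added_tokens_map_r.values().any(|val| val == token) {
        // Already registered as an added token: nothing to update
        // (the real code counts this case as "ignored").
    } else {
        // Point the forward map (content -> id) at the possibly reused id ...
        added_tokens_map
            .entry(token.to_string())
            .and_modify(|old_id| *old_id = id)
            .or_insert(id);
        // ... and keep the reverse map (id -> token) in sync.
        added_tokens_map_r
            .entry(id)
            .and_modify(|t| *t = token.to_string())
            .or_insert_with(|| token.to_string());
    }
    id
}

fn main() {
    let model_vocab = HashMap::from([("hello".to_string(), 0), ("world".to_string(), 1)]);
    let mut forward = HashMap::new();
    let mut reverse = HashMap::new();

    // "test" is unknown to the model, so it gets the next free id (2).
    assert_eq!(add_token(&model_vocab, &mut forward, &mut reverse, "test"), 2);
    // "hello" already has id 0 in the model; that id is reused, but the token
    // still lands in the reverse map, so it now counts as an added token.
    assert_eq!(add_token(&model_vocab, &mut forward, &mut reverse, "hello"), 0);
    assert_eq!(reverse[&0], "hello");
}

The behavior the updated tests below pin down shows up here as well: a token the model already covers keeps its model id, yet is still recorded in the reverse map and therefore counted as added.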
@@ -585,12 +602,12 @@ mod tests {
         );
         assert_eq!(vocab.len(), 2);
 
-        // Does not add tokens already covered by the model
-        assert_eq!(
-            vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
-            0
-        );
-        assert_eq!(vocab.len(), 2);
+        // Also adds tokens already covered by the model
+        let added_token = AddedToken::from("test", false);
+        assert_eq!(vocab.add_tokens(&[added_token.clone()], &model, normalizer), 1);
+        assert_eq!(vocab.len(), 3);
+
+        assert_eq!(vocab.get_vocab_r()[&1], added_token);
     }
 
     #[test]
@@ -626,10 +643,16 @@ mod tests {
         // Can add tokens already covered by the model
         assert_eq!(
             vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
-            0
+            1
         );
-        assert_eq!(vocab.len(), 2); // Did not add a new token, since it exist in the original model
+        assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
+        assert_eq!(*vocab.get_vocab_r(), HashMap::from([
+            (2, AddedToken::from("added_token_1", true)),
+            (3, AddedToken::from("added_token_2", true)),
+            (4, AddedToken::from("test", true)),
+        ]));
+
         assert!(!vocab.added_tokens_map.contains_key("test"));
     }
 