diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 44665b67..631d6293 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -93,6 +93,14 @@ pub struct AddedToken {
     /// Whether this token must be a single word or can break words
     pub single_word: bool,
 }
+impl AddedToken {
+    fn from(content: String) -> Self {
+        AddedToken {
+            content,
+            ..Default::default()
+        }
+    }
+}
 impl Default for AddedToken {
     fn default() -> Self {
         AddedToken {
@@ -206,12 +214,20 @@ impl Tokenizer {
 
     /// Converts a token in the corresponding id.
     pub fn token_to_id(&self, token: &str) -> Option<u32> {
-        self.model.token_to_id(token)
+        if let Some(id) = self.added_tokens.get(&AddedToken::from(token.to_owned())) {
+            Some(*id)
+        } else {
+            self.model.token_to_id(token)
+        }
     }
 
     /// Converts an id to the corresponding token.
     pub fn id_to_token(&self, id: u32) -> Option<String> {
-        self.model.id_to_token(id)
+        if let Some(token) = self.added_tokens_r.get(&id) {
+            Some(token.content.clone())
+        } else {
+            self.model.id_to_token(id)
+        }
     }
 
     /// Encode the given sentence