Include the added tokens while converting tokens and ids

This commit is contained in:
Anthony MOI
2019-12-19 18:32:37 -05:00
parent 076ba297fb
commit 7f032b62df

View File

@ -93,6 +93,14 @@ pub struct AddedToken {
/// Whether this token must be a single word or can break words /// Whether this token must be a single word or can break words
pub single_word: bool, pub single_word: bool,
} }
/// Conversion from an owned string into an [`AddedToken`].
///
/// Implemented through the standard `From` trait rather than an inherent `from` method, so
/// `AddedToken::from(content)` keeps working and `content.into()` becomes available as well.
impl From<String> for AddedToken {
    /// Builds an `AddedToken` whose `content` is the given string; every other field takes
    /// the value provided by `AddedToken::default()`.
    fn from(content: String) -> Self {
        AddedToken {
            content,
            ..Default::default()
        }
    }
}
impl Default for AddedToken { impl Default for AddedToken {
fn default() -> Self { fn default() -> Self {
AddedToken { AddedToken {
@ -206,13 +214,21 @@ impl Tokenizer {
/// Converts a token in the corresponding id. /// Converts a token in the corresponding id.
pub fn token_to_id(&self, token: &str) -> Option<u32> { pub fn token_to_id(&self, token: &str) -> Option<u32> {
if let Some(id) = self.added_tokens.get(&AddedToken::from(token.to_owned())) {
Some(*id)
} else {
self.model.token_to_id(token) self.model.token_to_id(token)
} }
}
/// Converts an id to the corresponding token. /// Converts an id to the corresponding token.
pub fn id_to_token(&self, id: u32) -> Option<String> { pub fn id_to_token(&self, id: u32) -> Option<String> {
if let Some(token) = self.added_tokens_r.get(&id) {
Some(token.content.clone())
} else {
self.model.id_to_token(id) self.model.id_to_token(id)
} }
}
/// Encode the given sentence /// Encode the given sentence
pub fn encode(&self, input: EncodeInput) -> Result<Encoding> { pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {