Ability to decode with added tokens

Anthony MOI
2019-12-16 18:22:46 -05:00
parent 4c7f6e1f04
commit f92e73b8f3
3 changed files with 12 additions and 20 deletions


@@ -184,15 +184,6 @@ impl Model for BPE {
         Ok(encoded)
     }
 
-    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
-        Ok(ids
-            .into_iter()
-            .map(|id| self.vocab_r.get(&id))
-            .filter(|token| token.is_some())
-            .map(|id| id.unwrap().clone())
-            .collect())
-    }
-
     fn token_to_id(&self, token: &str) -> Option<u32> {
         self.vocab.get(token).copied()
     }


@@ -141,15 +141,6 @@ impl Model for WordPiece {
         Ok(output_tokens)
     }
 
-    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
-        Ok(ids
-            .into_iter()
-            .map(|id| self.vocab_r.get(&id))
-            .filter(|token| token.is_some())
-            .map(|id| id.unwrap().clone())
-            .collect())
-    }
-
     fn token_to_id(&self, token: &str) -> Option<u32> {
         self.vocab.get(token).copied()
    }


@@ -38,7 +38,6 @@ pub trait PreTokenizer {
 /// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
 pub trait Model {
     fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
-    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
@@ -292,7 +291,18 @@ impl Tokenizer {
     /// Decode the given ids, back to a String
     pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
-        let tokens = self.model.decode(ids)?;
+        let tokens = ids
+            .into_iter()
+            .map(|id| {
+                if let Some(token) = self.added_tokens_r.get(&id) {
+                    Some(token.content.to_owned())
+                } else {
+                    self.model.id_to_token(id)
+                }
+            })
+            .filter(|token| token.is_some())
+            .map(|id| id.unwrap())
+            .collect::<Vec<_>>();
 
         if let Some(decoder) = &self.decoder {
             decoder.decode(tokens)
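
Taken together, the change removes the per-model `decode` implementations and performs the id-to-token lookup directly in `Tokenizer::decode`, consulting the added-tokens map before falling back to the model vocabulary. Below is a minimal, self-contained sketch of that lookup order; `MiniTokenizer` and its plain `HashMap` fields are hypothetical stand-ins used only for illustration (the real `added_tokens_r` stores an `AddedToken` with a `content` field), not the crate's actual types.

use std::collections::HashMap;

// Hypothetical stand-in for the relevant parts of `Tokenizer`.
struct MiniTokenizer {
    added_tokens_r: HashMap<u32, String>, // id -> added token content
    model_vocab_r: HashMap<u32, String>,  // id -> model token (stand-in for `model.id_to_token`)
}

impl MiniTokenizer {
    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
        ids.into_iter()
            .filter_map(|id| {
                // Added tokens take priority over the model vocabulary;
                // ids known to neither are silently skipped, as in the diff above.
                self.added_tokens_r
                    .get(&id)
                    .cloned()
                    .or_else(|| self.model_vocab_r.get(&id).cloned())
            })
            .collect()
    }
}

fn main() {
    let tokenizer = MiniTokenizer {
        added_tokens_r: HashMap::from([(3, "<special>".to_string())]),
        model_vocab_r: HashMap::from([(0, "hello".to_string()), (1, "world".to_string())]),
    };
    assert_eq!(
        tokenizer.decode(vec![0, 3, 1, 99]),
        vec!["hello", "<special>", "world"]
    );
}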