Option to skip special tokens while decoding

Anthony MOI
2019-12-19 20:03:02 -05:00
parent a8d68d516d
commit b7040e0412
3 changed files with 22 additions and 10 deletions


@@ -192,12 +192,16 @@ impl Tokenizer {
         .into()
     }
 
-    fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids)).into()
+    fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
+        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
     }
 
-    fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> PyResult<Vec<String>> {
-        ToPyResult(self.tokenizer.decode_batch(sentences)).into()
+    fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> PyResult<Vec<String>> {
+        ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {

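The binding change above is pure plumbing: both Python-facing methods gain a required skip_special_tokens argument and forward it unchanged to the core tokenizer. A minimal, self-contained sketch of that forwarding pattern (stub types only; CoreTokenizer and its body are illustrative stand-ins, not the real tk::Tokenizer):

    // Stub standing in for the core tokenizer; the real decode resolves
    // ids through the model and the added-token map.
    struct CoreTokenizer;

    impl CoreTokenizer {
        fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String, String> {
            Ok(format!("decoded {} ids (skip_special_tokens={})", ids.len(), skip_special_tokens))
        }
    }

    // Stand-in for the Python-facing wrapper: it adds no logic of its own
    // and only threads the new flag through to the core call.
    struct PyTokenizer {
        tokenizer: CoreTokenizer,
    }

    impl PyTokenizer {
        fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String, String> {
            self.tokenizer.decode(ids, skip_special_tokens)
        }
    }

    fn main() {
        let t = PyTokenizer { tokenizer: CoreTokenizer };
        println!("{}", t.decode(vec![101, 7592, 102], true).unwrap());
    }
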

@@ -53,7 +53,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec()).unwrap()
+            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }


@@ -342,15 +342,19 @@ impl Tokenizer {
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                if let Some(token) = self.added_tokens_r.get(&id) {
+                let token = if let Some(token) = self.added_tokens_r.get(&id) {
                     Some(token.content.to_owned())
                 } else {
                     self.model.id_to_token(id)
-                }
+                };
+
+                token.filter(|token| {
+                    !skip_special_tokens || !self.special_tokens.contains_key(token)
+                })
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
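The heart of the change is the Option::filter call: a decoded token survives when skipping is disabled, or when it is not registered as a special token. A standalone sketch of that predicate (the special_tokens map and the token strings below are hypothetical):

    use std::collections::HashMap;

    fn main() {
        // Hypothetical special-token vocabulary, standing in for the
        // tokenizer's special_tokens map.
        let special_tokens: HashMap<String, u32> = [
            ("[CLS]".to_string(), 101),
            ("[SEP]".to_string(), 102),
        ]
        .iter()
        .cloned()
        .collect();

        // Tokens as they might come out of id_to_token.
        let decoded = vec!["[CLS]", "hello", "world", "[SEP]"];

        for &skip_special_tokens in &[false, true] {
            let kept: Vec<&str> = decoded
                .iter()
                .copied()
                // Same predicate as in decode: keep the token unless we are
                // skipping specials and this token is one of them.
                .filter(|token| !skip_special_tokens || !special_tokens.contains_key(*token))
                .collect();
            println!("skip_special_tokens={} -> {:?}", skip_special_tokens, kept);
        }
    }
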
@@ -364,10 +368,14 @@ impl Tokenizer {
     }
 
     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
-            .map(|sentence| self.decode(sentence))
+            .map(|sentence| self.decode(sentence, skip_special_tokens))
             .collect()
     }
 
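decode_batch keeps its rayon-based parallelism and simply threads the flag into each decode call. A self-contained sketch of the same pattern (requires the rayon crate; the toy decode and the id values are illustrative):

    use rayon::prelude::*;

    // Toy stand-in for Tokenizer::decode; the real method resolves ids
    // through the model and the added-token map.
    fn decode(sentence: Vec<u32>, skip_special_tokens: bool) -> String {
        sentence
            .into_iter()
            // Pretend 101 and 102 are special-token ids.
            .filter(|id| !skip_special_tokens || (*id != 101 && *id != 102))
            .map(|id| id.to_string())
            .collect::<Vec<_>>()
            .join(" ")
    }

    fn main() {
        let sentences = vec![vec![101, 7592, 102], vec![101, 2088, 102]];

        // Same shape as the new decode_batch: sentences are decoded in
        // parallel and the flag is forwarded unchanged.
        let decoded: Vec<String> = sentences
            .into_par_iter()
            .map(|sentence| decode(sentence, true))
            .collect();

        println!("{:?}", decoded);
    }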