From b7040e04126acfebcdd76a14fb96fcae70396167 Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Thu, 19 Dec 2019 20:03:02 -0500
Subject: [PATCH] Option to skip special tokens while decoding

---
 bindings/python/src/tokenizer.rs | 12 ++++++++----
 tokenizers/src/cli.rs            |  2 +-
 tokenizers/src/tokenizer/mod.rs  | 18 +++++++++++++-----
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 3b9b6cac..8a25d64c 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -192,12 +192,16 @@ impl Tokenizer {
             .into()
     }
 
-    fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids)).into()
+    fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
+        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
     }
 
-    fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> PyResult<Vec<String>> {
-        ToPyResult(self.tokenizer.decode_batch(sentences)).into()
+    fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> PyResult<Vec<String>> {
+        ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index c4fed6db..1a91aca0 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -53,7 +53,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec()).unwrap()
+            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 9603eebd..ee8db5b2 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -342,15 +342,19 @@ impl Tokenizer {
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                if let Some(token) = self.added_tokens_r.get(&id) {
+                let token = if let Some(token) = self.added_tokens_r.get(&id) {
                     Some(token.content.to_owned())
                 } else {
                     self.model.id_to_token(id)
-                }
+                };
+
+                token.filter(|token| {
+                    !skip_special_tokens || !self.special_tokens.contains_key(token)
+                })
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
@@ -364,10 +368,14 @@ impl Tokenizer {
     }
 
     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
-            .map(|sentence| self.decode(sentence))
+            .map(|sentence| self.decode(sentence, skip_special_tokens))
             .collect()
     }
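
For reference, a minimal usage sketch of the new signatures (not part of the patch). It assumes the `tokenizers::tokenizer::{Result, Tokenizer}` import path based on the module edited above, and it elides tokenizer construction, which depends on the chosen model.

use tokenizers::tokenizer::{Result, Tokenizer};

// Decode the same ids twice: once keeping special tokens, once skipping them.
// `ids` would typically come from an Encoding produced by the tokenizer.
fn show_decoding(tokenizer: &Tokenizer, ids: Vec<u32>) -> Result<()> {
    let kept = tokenizer.decode(ids.clone(), false)?; // special tokens left in place
    let skipped = tokenizer.decode(ids, true)?;       // special tokens filtered out
    println!("with specials:    {}", kept);
    println!("without specials: {}", skipped);
    Ok(())
}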