diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 3b9b6cac..8a25d64c 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -192,12 +192,16 @@ impl Tokenizer {
         .into()
     }

-    fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids)).into()
+    fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
+        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
     }

-    fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> PyResult<Vec<String>> {
-        ToPyResult(self.tokenizer.decode_batch(sentences)).into()
+    fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> PyResult<Vec<String>> {
+        ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
     }

     fn token_to_id(&self, token: &str) -> Option<u32> {
diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index c4fed6db..1a91aca0 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -53,7 +53,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec()).unwrap()
+            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 9603eebd..ee8db5b2 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -342,15 +342,19 @@ impl Tokenizer {
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                if let Some(token) = self.added_tokens_r.get(&id) {
+                let token = if let Some(token) = self.added_tokens_r.get(&id) {
                     Some(token.content.to_owned())
                 } else {
                     self.model.id_to_token(id)
-                }
+                };
+
+                token.filter(|token| {
+                    !skip_special_tokens || !self.special_tokens.contains_key(token)
+                })
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
@@ -364,10 +368,14 @@ impl Tokenizer {
     }

     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
-            .map(|sentence| self.decode(sentence))
+            .map(|sentence| self.decode(sentence, skip_special_tokens))
             .collect()
     }
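
For reference, a minimal sketch of how the changed API would be called from the Rust side, assuming a fully configured `tokenizers::tokenizer::Tokenizer` is available; the helper name `print_decoded` and the import path are illustrative, only `decode`, `decode_batch`, and the new `skip_special_tokens` parameter come from the diff above:

use tokenizers::tokenizer::Tokenizer;

// Sketch only: `tokenizer` is assumed to be built and configured elsewhere.
fn print_decoded(tokenizer: &Tokenizer, ids: Vec<u32>) {
    // `false` keeps special tokens (e.g. [CLS]/[SEP]) in the decoded text.
    println!("{}", tokenizer.decode(ids.clone(), false).unwrap());
    // `true` filters out any id whose token is registered in `special_tokens`.
    println!("{}", tokenizer.decode(ids, true).unwrap());
}

// The same flag applies to every sentence of a batch:
// tokenizer.decode_batch(vec![vec![0, 42, 7], vec![13, 5]], true)

On the Python side, the binding as written here exposes the flag as a second required positional argument, e.g. `tokenizer.decode(ids, True)`, since the diff does not declare a default value for it.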