diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index ebc68dfb..b0836ca3 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -847,35 +847,23 @@ where
     /// Decode the given ids, back to a String
     pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
-        let mut result = String::with_capacity(ids.len());
-        let mut chunks = Vec::with_capacity(ids.len());
-        for id in ids {
-            if let Some(added_token) = self.added_vocabulary.simple_id_to_token(*id) {
-                if skip_special_tokens && self.added_vocabulary.is_special_token(&added_token) {
-                    continue;
-                }
-                let text_chunk = if let Some(decoder) = &self.decoder {
-                    decoder.decode(chunks.clone())?
-                } else {
-                    chunks.join(" ")
-                };
-                result.push_str(&text_chunk);
-                if !result.is_empty() && self.decoder.is_none() {
-                    result.push(' ');
-                }
-                result.push_str(&added_token);
-                chunks.clear();
-            } else if let Some(token) = self.model.id_to_token(*id) {
-                chunks.push(token);
-            }
-        }
-        let text_chunk = if let Some(decoder) = &self.decoder {
-            decoder.decode(chunks.clone())?
+        let tokens = ids
+            .iter()
+            .filter_map(|id| {
+                self.added_vocabulary
+                    .simple_id_to_token(*id)
+                    .or_else(|| self.model.id_to_token(*id))
+                    .filter(|token| {
+                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
+                    })
+            })
+            .collect::<Vec<_>>();
+
+        if let Some(decoder) = &self.decoder {
+            decoder.decode(tokens)
         } else {
-            chunks.join(" ")
-        };
-        result.push_str(&text_chunk);
-        Ok(result)
+            Ok(tokens.join(" "))
+        }
     }
 }
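For reference, the refactored control flow boils down to: resolve every id to a token string in a single pass (added vocabulary first, falling back to the model vocabulary, dropping special tokens when requested), then hand the whole token sequence to the decoder, or space-join it when no decoder is configured. Below is a minimal standalone sketch of that flow; `ToyTokenizer`, the plain `HashMap` vocabularies, and the id-based special-token check are toy stand-ins, not the crate's real `Model`/`AddedVocabulary`/`Decoder` types, and only the no-decoder fallback is modeled.

```rust
use std::collections::HashMap;

// Toy stand-in for a tokenizer; not the tokenizers crate's real types.
struct ToyTokenizer {
    vocab: HashMap<u32, String>, // stand-in for Model::id_to_token
    added: HashMap<u32, String>, // stand-in for AddedVocabulary::simple_id_to_token
    special_ids: Vec<u32>,       // ids treated as special tokens
}

impl ToyTokenizer {
    fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> String {
        // Single pass: resolve each id against the added vocabulary first,
        // fall back to the model vocabulary, and filter out special tokens
        // when asked -- mirroring the filter_map in the new decode().
        let tokens: Vec<String> = ids
            .iter()
            .filter_map(|id| {
                self.added
                    .get(id)
                    .or_else(|| self.vocab.get(id))
                    .filter(|_| !skip_special_tokens || !self.special_ids.contains(id))
                    .cloned()
            })
            .collect();

        // The real code hands `tokens` to the configured decoder here;
        // this sketch only models the no-decoder space-join fallback.
        tokens.join(" ")
    }
}

fn main() {
    let tok = ToyTokenizer {
        vocab: HashMap::from([(1, "hello".into()), (2, "world".into())]),
        added: HashMap::from([(0, "<s>".into())]),
        special_ids: vec![0],
    };
    assert_eq!(tok.decode(&[0, 1, 2], true), "hello world");
    assert_eq!(tok.decode(&[0, 1, 2], false), "<s> hello world");
}
```

One consequence of the new shape worth noting: added tokens are no longer spliced into the output around per-chunk decoder calls; they travel through the same pipeline as regular tokens, so a configured decoder now sees the full token sequence in one call.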