diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 893d2743..b6768110 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -1106,7 +1106,7 @@ where
         }
         let new_text = &string[prefix.len()..].to_string();
         let new_prefix_index = ids.len() - *prefix_index;
-        *ids = ids.drain(*read_index..).collect();
+        *ids = ids.drain(*prefix_index..).collect();
         *prefix = tokenizer.decode(ids, skip_special_tokens)?;
         *read_index = *prefix_index;
         *prefix_index = new_prefix_index;
@@ -1616,4 +1616,59 @@ mod test {
         let decoded = tokenizer.decode(encoded.get_ids(), false);
         assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
     }
+
+    #[cfg(feature = "http")]
+    #[test]
+    fn test_decode_stream_step_no_panic() {
+        use std::panic;
+
+        use crate::Tokenizer;
+
+        let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
+
+        // "A B C D E F G H I J"
+        let mut decode_stream = tokenizer.decode_stream(false);
+        let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
+        let expected_outputs = vec![
+            Some("A".to_string()),
+            Some(" B".to_string()),
+            Some(" C".to_string()),
+            Some(" D".to_string()),
+            Some(" E".to_string()),
+            Some(" F".to_string()),
+            Some(" G".to_string()),
+            Some(" H".to_string()),
+            Some(" I".to_string()),
+            Some(" J".to_string()),
+        ];
+        for (i, &token) in output_tokens.iter().enumerate() {
+            let maybe_panic =
+                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
+            assert!(maybe_panic.is_ok());
+            let result = maybe_panic.unwrap();
+            assert!(result.is_ok());
+            assert_eq!(result.unwrap(), expected_outputs[i]);
+        }
+
+        // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
+        let mut decode_stream = tokenizer.decode_stream(false);
+        let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
+        let expected_outputs = vec![
+            None,
+            Some("삥".to_string()),
+            None,
+            None,
+            Some("뽕".to_string()),
+            None,
+            Some("빵".to_string()),
+        ];
+        for (i, &token) in output_tokens.iter().enumerate() {
+            let maybe_panic =
+                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
+            assert!(maybe_panic.is_ok());
+            let result = maybe_panic.unwrap();
+            assert!(result.is_ok());
+            assert_eq!(result.unwrap(), expected_outputs[i]);
+        }
+    }
 }