Fix panic in DecodeStream::step due to incorrect index usage (#1699)

* Add a failing test for step_decode_stream

* Improve test case for test_decode_stream_step_no_panic

* Fix subtract with overflow issue in step_decode_stream
This commit is contained in:
Sungyoon Jeong
2025-01-09 21:24:04 +09:00
committed by GitHub
parent c04b97aab1
commit 862d1a346a

View File

@ -1106,7 +1106,7 @@ where
}
let new_text = &string[prefix.len()..].to_string();
let new_prefix_index = ids.len() - *prefix_index;
*ids = ids.drain(*read_index..).collect();
*ids = ids.drain(*prefix_index..).collect();
*prefix = tokenizer.decode(ids, skip_special_tokens)?;
*read_index = *prefix_index;
*prefix_index = new_prefix_index;
@ -1616,4 +1616,59 @@ mod test {
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
}
#[cfg(feature = "http")]
#[test]
fn test_decode_stream_step_no_panic() {
use std::panic;
use crate::Tokenizer;
let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
// "A B C D E F G H I J"
let mut decode_stream = tokenizer.decode_stream(false);
let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
let expected_outputs = vec![
Some("A".to_string()),
Some(" B".to_string()),
Some(" C".to_string()),
Some(" D".to_string()),
Some(" E".to_string()),
Some(" F".to_string()),
Some(" G".to_string()),
Some(" H".to_string()),
Some(" I".to_string()),
Some(" J".to_string()),
];
for (i, &token) in output_tokens.iter().enumerate() {
let maybe_panic =
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
assert!(maybe_panic.is_ok());
let result = maybe_panic.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected_outputs[i]);
}
// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
let mut decode_stream = tokenizer.decode_stream(false);
let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
let expected_outputs = vec![
None,
Some("".to_string()),
None,
None,
Some("".to_string()),
None,
Some("".to_string()),
];
for (i, &token) in output_tokens.iter().enumerate() {
let maybe_panic =
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
assert!(maybe_panic.is_ok());
let result = maybe_panic.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected_outputs[i]);
}
}
}