Fix panic in DecodeStream::step due to incorrect index usage (#1699)
* Add a failing test for step_decode_stream
* Improve test case for test_decode_stream_step_no_panic
* Fix subtract with overflow issue in step_decode_stream
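For orientation, here is a minimal sketch of how the streaming decode API touched by this commit is typically driven from application code. It is not part of the commit; the model id and token ids are taken from the new test below, and fetching the pretrained tokenizer assumes the crate's http feature is enabled and network access is available.

    use tokenizers::{Result, Tokenizer};

    fn main() -> Result<()> {
        // Same pretrained tokenizer the new test uses; requires the `http`
        // feature and network access to download it.
        let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None)?;

        // Feed ids one at a time: `step` returns Ok(Some(text)) once a complete
        // chunk can be decoded, and Ok(None) while a multi-token character is
        // still incomplete.
        let mut stream = tokenizer.decode_stream(false);
        for id in [80690u32, 98, 167, 121, 243, 102457, 113] {
            if let Some(chunk) = stream.step(id)? {
                print!("{chunk}"); // prints "삥", "뽕", "빵" as they complete
            }
        }
        Ok(())
    }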
@@ -1106,7 +1106,7 @@ where
         }
         let new_text = &string[prefix.len()..].to_string();
         let new_prefix_index = ids.len() - *prefix_index;
-        *ids = ids.drain(*read_index..).collect();
+        *ids = ids.drain(*prefix_index..).collect();
         *prefix = tokenizer.decode(ids, skip_special_tokens)?;
         *read_index = *prefix_index;
         *prefix_index = new_prefix_index;
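The panic being fixed is a plain usize underflow. The following is a minimal, self-contained sketch of the failure class (illustrative values only, not the library's actual state): draining the buffer at the wrong index can leave it with fewer elements than prefix_index accounts for, so the next ids.len() - prefix_index subtraction panics in debug builds.

    fn main() {
        // Hypothetical state, chosen only to show the arithmetic.
        let mut ids: Vec<u32> = vec![1, 2, 3, 4, 5];
        let prefix_index: usize = 3;
        let wrong_read_index: usize = 4;

        // Draining at the wrong index keeps only ids[4..] = [5], i.e. fewer
        // elements than prefix_index assumes are still present.
        ids = ids.drain(wrong_read_index..).collect();

        // 1usize - 3usize panics with "attempt to subtract with overflow" in
        // debug builds (and wraps in release). Draining at prefix_index
        // instead keeps prefix_index <= ids.len().
        let new_prefix_index = ids.len() - prefix_index;
        println!("{new_prefix_index}");
    }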
@@ -1616,4 +1616,59 @@ mod test {
         let decoded = tokenizer.decode(encoded.get_ids(), false);
         assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
     }
+
+    #[cfg(feature = "http")]
+    #[test]
+    fn test_decode_stream_step_no_panic() {
+        use std::panic;
+
+        use crate::Tokenizer;
+
+        let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
+
+        // "A B C D E F G H I J"
+        let mut decode_stream = tokenizer.decode_stream(false);
+        let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
+        let expected_outputs = vec![
+            Some("A".to_string()),
+            Some(" B".to_string()),
+            Some(" C".to_string()),
+            Some(" D".to_string()),
+            Some(" E".to_string()),
+            Some(" F".to_string()),
+            Some(" G".to_string()),
+            Some(" H".to_string()),
+            Some(" I".to_string()),
+            Some(" J".to_string()),
+        ];
+        for (i, &token) in output_tokens.iter().enumerate() {
+            let maybe_panic =
+                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
+            assert!(maybe_panic.is_ok());
+            let result = maybe_panic.unwrap();
+            assert!(result.is_ok());
+            assert_eq!(result.unwrap(), expected_outputs[i]);
+        }
+
+        // "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
+        let mut decode_stream = tokenizer.decode_stream(false);
+        let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
+        let expected_outputs = vec![
+            None,
+            Some("삥".to_string()),
+            None,
+            None,
+            Some("뽕".to_string()),
+            None,
+            Some("빵".to_string()),
+        ];
+        for (i, &token) in output_tokens.iter().enumerate() {
+            let maybe_panic =
+                panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
+            assert!(maybe_panic.is_ok());
+            let result = maybe_panic.unwrap();
+            assert!(result.is_ok());
+            assert_eq!(result.unwrap(), expected_outputs[i]);
+        }
+    }
 }
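Note that the new test is gated behind cfg(feature = "http") because Tokenizer::from_pretrained downloads the pretrained tokenizer from the Hugging Face Hub, so running it locally requires building with the http feature and having network access.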