Fixing the stream by removing the read_index altogether. (#1716)

* Fixing the stream by removing the read_index altogether.

* Moving the test location because... Windows.

* Ok whatever.

* Rust 1.84

* Fmt.
Nicolas Patry
2025-01-09 17:41:15 +01:00
committed by GitHub
parent 862d1a346a
commit 0ff2ab0f64
8 changed files with 89 additions and 132 deletions

View File

@ -14,8 +14,8 @@ style:
# Check the source code is formatted correctly
check-style:
python stub.py --check
- ruff check examples py_src/tokenizers tests
- ruff format --check examples py_src/tokenizers tests
+ ruff check $(check_dirs)
+ ruff format --check $(check_dirs)
TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json

View File

@ -241,7 +241,7 @@ class EncodingVisualizer:
# In this case we are looking at a group/single char that is not tokenized.
# e.g. white space
css_classes.append("non-token")
- css = f'''class="{' '.join(css_classes)}"'''
+ css = f'''class="{" ".join(css_classes)}"'''
data = ""
for key, val in data_items.items():
data += f' data-{key}="{val}"'

View File

@ -646,11 +646,6 @@ pub struct PyDecodeStream {
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
- /// We need to keep 2 prefixes.
- /// Prefix is the second one that was already emitted to discard the part
- /// of the text of all the ids
- /// read is the prefix kept only for starting side effects of the prefix
- read_index: usize,
}
#[pymethods]
@ -663,7 +658,6 @@ impl PyDecodeStream {
ids: vec![],
prefix: "".to_string(),
prefix_index: 0,
- read_index: 0,
}
}
@ -676,7 +670,6 @@ impl PyDecodeStream {
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
- &mut self.read_index,
))
.into()
}

View File

@ -4,7 +4,7 @@ TESTS_DIR = tests
dir_guard=@mkdir -p $(@D)
- SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json
+ SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
@ -79,3 +79,7 @@ $(DATA_DIR)/tokenizer-wiki.json :
$(DATA_DIR)/bert-wiki.json :
$(dir_guard)
wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json -O $@
+ $(DATA_DIR)/llama-3-tokenizer.json :
+ $(dir_guard)
+ wget https://huggingface.co/hf-internal-testing/llama3-tokenizer/resolve/main/tokenizer.json -O $@

View File

@ -199,9 +199,9 @@ impl Word {
// Make sure we are not processing an expired queue entry
let target_new_pair = (self.symbols[top.pos].c, right.c);
- if !merges
+ if merges
.get(&target_new_pair)
- .map_or(false, |(_, new_id)| *new_id == top.new_id)
+ .is_none_or(|(_, new_id)| *new_id != top.new_id)
{
continue;
}
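Note: the map_or(false, …) rewrites in this hunk and in the TemplateProcessingBuilder hunk below are the "Rust 1.84" part of the commit. They are behavior-preserving swaps to Option::is_none_or / Option::is_some_and, most likely prompted by clippy's unnecessary_map_or lint on the newer toolchain. A minimal standalone check of the equivalence, using an invented predicate rather than the library's types:

// Not part of the commit: sanity check that the rewrites preserve behavior.
//   !opt.map_or(false, pred)  <=>  opt.is_none_or(|x| !pred(x))
//    opt.map_or(false, pred)  <=>  opt.is_some_and(pred)
// Needs a toolchain where Option::is_none_or is stable (Rust 1.82+).
fn main() {
    let pred = |v: u32| v == 7;
    for opt in [None, Some(7_u32), Some(9_u32)] {
        // Old idiom vs. the new one used in word.rs above.
        assert_eq!(!opt.map_or(false, pred), opt.is_none_or(|v| !pred(v)));
        // Old idiom vs. the new one used in the template processor below.
        assert_eq!(opt.map_or(false, pred), opt.is_some_and(pred));
    }
    println!("map_or rewrites are equivalent");
}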

View File

@ -441,7 +441,7 @@ impl TemplateProcessingBuilder {
let exist = self
.special_tokens
.as_ref()
- .map_or(false, |map| map.0.contains_key(sp));
+ .is_some_and(|map| map.0.contains_key(sp));
match exist {
false => Some(sp),

View File

@ -1035,11 +1035,6 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {
/// The index within the ids corresponding to the prefix so we can drain
/// correctly
prefix_index: usize,
- /// We need to keep 2 prefixes.
- /// Prefix is the second one that was already emitted to discard the part
- /// of the text of all the ids
- /// read is the prefix kept only for starting side effects of the prefix
- read_index: usize,
}
#[derive(thiserror::Error, Debug)]
@ -1063,7 +1058,6 @@ where
skip_special_tokens,
prefix: "".to_string(),
prefix_index: 0,
- read_index: 0,
}
}
@ -1076,7 +1070,6 @@ where
&mut self.ids,
&mut self.prefix,
&mut self.prefix_index,
- &mut self.read_index,
)
}
}
@ -1089,7 +1082,6 @@ pub fn step_decode_stream<M, N, PT, PP, D>(
ids: &mut Vec<u32>,
prefix: &mut String,
prefix_index: &mut usize,
- read_index: &mut usize,
) -> Result<Option<String>>
where
M: Model,
@ -1108,7 +1100,6 @@ where
let new_prefix_index = ids.len() - *prefix_index;
*ids = ids.drain(*prefix_index..).collect();
*prefix = tokenizer.decode(ids, skip_special_tokens)?;
- *read_index = *prefix_index;
*prefix_index = new_prefix_index;
Ok(Some(new_text.to_string()))
} else {
@ -1563,112 +1554,3 @@ where
Ok(())
}
}
#[cfg(test)]
mod test {
#[cfg(feature = "http")]
#[test]
fn test_decoding_with_added_bpe() {
use crate::{
normalizers,
pre_tokenizers::split::{Split, SplitPattern},
AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
};
let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
Split::new(
SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into()),
SplitDelimiterBehavior::Isolated,
false,
)
.unwrap(),
));
tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
let encoded = tokenizer
.encode("Hey! how is this token: 嗎", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", ""]
);
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
let encoded = tokenizer
.encode("Hey! how is this token: д", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
);
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
}
#[cfg(feature = "http")]
#[test]
fn test_decode_stream_step_no_panic() {
use std::panic;
use crate::Tokenizer;
let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
// "A B C D E F G H I J"
let mut decode_stream = tokenizer.decode_stream(false);
let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
let expected_outputs = vec![
Some("A".to_string()),
Some(" B".to_string()),
Some(" C".to_string()),
Some(" D".to_string()),
Some(" E".to_string()),
Some(" F".to_string()),
Some(" G".to_string()),
Some(" H".to_string()),
Some(" I".to_string()),
Some(" J".to_string()),
];
for (i, &token) in output_tokens.iter().enumerate() {
let maybe_panic =
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
assert!(maybe_panic.is_ok());
let result = maybe_panic.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected_outputs[i]);
}
// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
let mut decode_stream = tokenizer.decode_stream(false);
let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
let expected_outputs = vec![
None,
Some("삥".to_string()),
None,
None,
Some("뽕".to_string()),
None,
Some("빵".to_string()),
];
for (i, &token) in output_tokens.iter().enumerate() {
let maybe_panic =
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
assert!(maybe_panic.is_ok());
let result = maybe_panic.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), expected_outputs[i]);
}
}
}
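After these hunks, the stream's state is reduced to ids, prefix, and prefix_index; the second "read" prefix that the removed doc comment described is gone. Below is a crate-independent sketch of the remaining bookkeeping in step_decode_stream. ToyDecodeStream, the decode stub, and its vocabulary are invented for illustration; they are not the tokenizers API.

// Sketch only: mirrors the shape of step_decode_stream after this commit.
// `decode` stands in for Tokenizer::decode with a made-up three-token vocabulary.
fn decode(ids: &[u32]) -> String {
    ids.iter()
        .map(|id| match id {
            0 => "Hello",
            1 => " world",
            2 => "!",
            _ => "<unk>",
        })
        .collect()
}

#[derive(Default)]
struct ToyDecodeStream {
    ids: Vec<u32>,        // tail of token ids still needed for context
    prefix: String,       // decoded text of `ids` that was already emitted
    prefix_index: usize,  // where the previous prefix ends inside `ids`
}

impl ToyDecodeStream {
    fn step(&mut self, id: u32) -> Option<String> {
        self.ids.push(id);
        let string = decode(&self.ids);
        // Emit only the part of the text that extends the already-emitted prefix.
        if string.len() > self.prefix.len() && string.starts_with(self.prefix.as_str()) {
            let new_text = string[self.prefix.len()..].to_string();
            let new_prefix_index = self.ids.len() - self.prefix_index;
            // Drop the ids that belonged to the old prefix, keep the new tail.
            let kept: Vec<u32> = self.ids.drain(self.prefix_index..).collect();
            self.ids = kept;
            self.prefix = decode(&self.ids);
            self.prefix_index = new_prefix_index;
            Some(new_text)
        } else {
            // The real implementation also returns None while the decoded text
            // still ends in an incomplete character, as the tests below exercise.
            None
        }
    }
}

fn main() {
    let mut stream = ToyDecodeStream::default();
    for id in [0_u32, 1, 2] {
        println!("{:?}", stream.step(id)); // Some("Hello"), Some(" world"), Some("!")
    }
}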

View File

@ -0,0 +1,78 @@
use tokenizers::{
normalizers,
pre_tokenizers::split::{Split, SplitPattern},
AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
};
#[test]
fn test_decoding_with_added_bpe() {
let mut tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
tokenizer.with_normalizer(Some(NormalizerWrapper::from(normalizers::ByteLevel::new())));
tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Split(
Split::new(
SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+".into()),
SplitDelimiterBehavior::Isolated,
false,
)
.unwrap(),
)));
tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
let encoded = tokenizer
.encode("Hey! how is this token: 嗎", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", ""]
);
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
let encoded = tokenizer
.encode("Hey! how is this token: д", false)
.unwrap();
assert_eq!(
encoded.get_ids(),
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
);
assert_eq!(
encoded.get_tokens(),
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
);
let decoded = tokenizer.decode(encoded.get_ids(), false);
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
}
#[test]
fn test_decode_stream_step_no_panic() {
let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
// "A B C D E F G H I J"
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(32).unwrap(), Some("A".to_string()));
assert_eq!(decode_stream.step(426).unwrap(), Some(" B".to_string()));
assert_eq!(decode_stream.step(356).unwrap(), Some(" C".to_string()));
assert_eq!(decode_stream.step(423).unwrap(), Some(" D".to_string()));
assert_eq!(decode_stream.step(469).unwrap(), Some(" E".to_string()));
assert_eq!(decode_stream.step(435).unwrap(), Some(" F".to_string()));
assert_eq!(decode_stream.step(480).unwrap(), Some(" G".to_string()));
assert_eq!(decode_stream.step(473).unwrap(), Some(" H".to_string()));
assert_eq!(decode_stream.step(358).unwrap(), Some(" I".to_string()));
assert_eq!(decode_stream.step(622).unwrap(), Some(" J".to_string()));
// for (i, &token) in output_tokens.iter().enumerate() {}
// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
let mut decode_stream = tokenizer.decode_stream(false);
assert_eq!(decode_stream.step(80690).unwrap(), None);
assert_eq!(decode_stream.step(98).unwrap(), Some("삥".to_string()));
assert_eq!(decode_stream.step(167).unwrap(), None);
assert_eq!(decode_stream.step(121).unwrap(), None);
assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".to_string()));
assert_eq!(decode_stream.step(102457).unwrap(), None);
assert_eq!(decode_stream.step(113).unwrap(), Some("빵".to_string()));
}
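The None steps in the Korean case are expected: the llama-3 tokenizer here decodes at the byte level (the test installs a ByteLevel normalizer), so after only some of a character's tokens the decoded bytes form an incomplete UTF-8 sequence and the stream holds the text back until the character is whole. A standalone illustration of why a partial decode looks incomplete, using only the standard library (nothing here is tokenizers API):

// Not from the commit: "삥" is three UTF-8 bytes, so decoding a prefix of them
// yields U+FFFD (the replacement character) and nothing printable yet.
fn main() {
    let bytes = "삥".as_bytes();
    assert_eq!(bytes.len(), 3);

    // Only the first two bytes available: lossy decoding gives the replacement char.
    let partial = String::from_utf8_lossy(&bytes[..2]);
    assert!(partial.contains('\u{FFFD}'));

    // Once the last byte arrives, the character decodes cleanly.
    let full = String::from_utf8_lossy(bytes);
    assert_eq!(full, "삥");

    println!("partial = {partial:?}, full = {full:?}");
}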