mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Fixing the stream by removing the read_index altogether. (#1716)
* Fixing the stream by removing the read_index altogether. * Moving the test location because.. Windows. * Ok whatever. * Rust 1.84 * Fmt.
This commit is contained in:
@ -14,8 +14,8 @@ style:
|
||||
# Check the source code is formatted correctly
|
||||
check-style:
|
||||
python stub.py --check
|
||||
ruff check examples py_src/tokenizers tests
|
||||
ruff format --check examples py_src/tokenizers tests
|
||||
ruff check $(check_dirs)
|
||||
ruff format --check $(check_dirs)
|
||||
|
||||
TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json
|
||||
|
||||
|
@ -241,7 +241,7 @@ class EncodingVisualizer:
|
||||
# In this case we are looking at a group/single char that is not tokenized.
|
||||
# e.g. white space
|
||||
css_classes.append("non-token")
|
||||
css = f'''class="{' '.join(css_classes)}"'''
|
||||
css = f'''class="{" ".join(css_classes)}"'''
|
||||
data = ""
|
||||
for key, val in data_items.items():
|
||||
data += f' data-{key}="{val}"'
|
||||
|
@ -646,11 +646,6 @@ pub struct PyDecodeStream {
|
||||
/// The index within the ids corresponding to the prefix so we can drain
|
||||
/// correctly
|
||||
prefix_index: usize,
|
||||
/// We need to keep 2 prefixes.
|
||||
/// Prefix is the second one that was already emitted to discard the part
|
||||
/// of the text of all the ids
|
||||
/// read is the prefix kept only for starting side effects of the prefix
|
||||
read_index: usize,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@ -663,7 +658,6 @@ impl PyDecodeStream {
|
||||
ids: vec![],
|
||||
prefix: "".to_string(),
|
||||
prefix_index: 0,
|
||||
read_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -676,7 +670,6 @@ impl PyDecodeStream {
|
||||
&mut self.ids,
|
||||
&mut self.prefix,
|
||||
&mut self.prefix_index,
|
||||
&mut self.read_index,
|
||||
))
|
||||
.into()
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ TESTS_DIR = tests
|
||||
|
||||
dir_guard=@mkdir -p $(@D)
|
||||
|
||||
SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json
|
||||
SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
|
||||
BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
|
||||
TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
|
||||
|
||||
@ -79,3 +79,7 @@ $(DATA_DIR)/tokenizer-wiki.json :
|
||||
$(DATA_DIR)/bert-wiki.json :
|
||||
$(dir_guard)
|
||||
wget https://s3.amazonaws.com/models.huggingface.co/bert/anthony/doc-pipeline/tokenizer.json -O $@
|
||||
|
||||
$(DATA_DIR)/llama-3-tokenizer.json :
|
||||
$(dir_guard)
|
||||
wget https://huggingface.co/hf-internal-testing/llama3-tokenizer/resolve/main/tokenizer.json -O $@
|
||||
|
@ -199,9 +199,9 @@ impl Word {
|
||||
|
||||
// Make sure we are not processing an expired queue entry
|
||||
let target_new_pair = (self.symbols[top.pos].c, right.c);
|
||||
if !merges
|
||||
if merges
|
||||
.get(&target_new_pair)
|
||||
.map_or(false, |(_, new_id)| *new_id == top.new_id)
|
||||
.is_none_or(|(_, new_id)| *new_id != top.new_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
@ -441,7 +441,7 @@ impl TemplateProcessingBuilder {
|
||||
let exist = self
|
||||
.special_tokens
|
||||
.as_ref()
|
||||
.map_or(false, |map| map.0.contains_key(sp));
|
||||
.is_some_and(|map| map.0.contains_key(sp));
|
||||
|
||||
match exist {
|
||||
false => Some(sp),
|
||||
|
@ -1035,11 +1035,6 @@ pub struct DecodeStream<'tok, M, N, PT, PP, D> {
|
||||
/// The index within the ids corresponding to the prefix so we can drain
|
||||
/// correctly
|
||||
prefix_index: usize,
|
||||
/// We need to keep 2 prefixes.
|
||||
/// Prefix is the second one that was already emitted to discard the part
|
||||
/// of the text of all the ids
|
||||
/// read is the prefix kept only for starting side effects of the prefix
|
||||
read_index: usize,
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -1063,7 +1058,6 @@ where
|
||||
skip_special_tokens,
|
||||
prefix: "".to_string(),
|
||||
prefix_index: 0,
|
||||
read_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -1076,7 +1070,6 @@ where
|
||||
&mut self.ids,
|
||||
&mut self.prefix,
|
||||
&mut self.prefix_index,
|
||||
&mut self.read_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
@ -1089,7 +1082,6 @@ pub fn step_decode_stream<M, N, PT, PP, D>(
|
||||
ids: &mut Vec<u32>,
|
||||
prefix: &mut String,
|
||||
prefix_index: &mut usize,
|
||||
read_index: &mut usize,
|
||||
) -> Result<Option<String>>
|
||||
where
|
||||
M: Model,
|
||||
@ -1108,7 +1100,6 @@ where
|
||||
let new_prefix_index = ids.len() - *prefix_index;
|
||||
*ids = ids.drain(*prefix_index..).collect();
|
||||
*prefix = tokenizer.decode(ids, skip_special_tokens)?;
|
||||
*read_index = *prefix_index;
|
||||
*prefix_index = new_prefix_index;
|
||||
Ok(Some(new_text.to_string()))
|
||||
} else {
|
||||
@ -1563,112 +1554,3 @@ where
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[cfg(feature = "http")]
|
||||
#[test]
|
||||
fn test_decoding_with_added_bpe() {
|
||||
use crate::{
|
||||
normalizers,
|
||||
pre_tokenizers::split::{Split, SplitPattern},
|
||||
AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
|
||||
};
|
||||
|
||||
let mut tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
|
||||
tokenizer.normalizer = Some(NormalizerWrapper::from(normalizers::ByteLevel::new()));
|
||||
tokenizer.pre_tokenizer = Some(PreTokenizerWrapper::Split(
|
||||
Split::new(
|
||||
SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
|
||||
SplitDelimiterBehavior::Isolated,
|
||||
false,
|
||||
)
|
||||
.unwrap(),
|
||||
));
|
||||
tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
|
||||
let encoded = tokenizer
|
||||
.encode("Hey! how is this token: 嗎", false)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
encoded.get_ids(),
|
||||
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
|
||||
);
|
||||
assert_eq!(
|
||||
encoded.get_tokens(),
|
||||
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
|
||||
);
|
||||
|
||||
let decoded = tokenizer.decode(encoded.get_ids(), false);
|
||||
assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
|
||||
|
||||
tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
|
||||
let encoded = tokenizer
|
||||
.encode("Hey! how is this token: д", false)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
encoded.get_ids(),
|
||||
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
|
||||
);
|
||||
assert_eq!(
|
||||
encoded.get_tokens(),
|
||||
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
|
||||
);
|
||||
let decoded = tokenizer.decode(encoded.get_ids(), false);
|
||||
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
|
||||
}
|
||||
|
||||
#[cfg(feature = "http")]
|
||||
#[test]
|
||||
fn test_decode_stream_step_no_panic() {
|
||||
use std::panic;
|
||||
|
||||
use crate::Tokenizer;
|
||||
|
||||
let tokenizer = Tokenizer::from_pretrained("meta-llama/Meta-Llama-3-8B", None).unwrap();
|
||||
|
||||
// "A B C D E F G H I J"
|
||||
let mut decode_stream = tokenizer.decode_stream(false);
|
||||
let output_tokens = vec![32, 426, 356, 423, 469, 435, 480, 473, 358, 622];
|
||||
let expected_outputs = vec![
|
||||
Some("A".to_string()),
|
||||
Some(" B".to_string()),
|
||||
Some(" C".to_string()),
|
||||
Some(" D".to_string()),
|
||||
Some(" E".to_string()),
|
||||
Some(" F".to_string()),
|
||||
Some(" G".to_string()),
|
||||
Some(" H".to_string()),
|
||||
Some(" I".to_string()),
|
||||
Some(" J".to_string()),
|
||||
];
|
||||
for (i, &token) in output_tokens.iter().enumerate() {
|
||||
let maybe_panic =
|
||||
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
|
||||
assert!(maybe_panic.is_ok());
|
||||
let result = maybe_panic.unwrap();
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), expected_outputs[i]);
|
||||
}
|
||||
|
||||
// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
|
||||
let mut decode_stream = tokenizer.decode_stream(false);
|
||||
let output_tokens = vec![80690, 98, 167, 121, 243, 102457, 113];
|
||||
let expected_outputs = vec![
|
||||
None,
|
||||
Some("삥".to_string()),
|
||||
None,
|
||||
None,
|
||||
Some("뽕".to_string()),
|
||||
None,
|
||||
Some("빵".to_string()),
|
||||
];
|
||||
for (i, &token) in output_tokens.iter().enumerate() {
|
||||
let maybe_panic =
|
||||
panic::catch_unwind(panic::AssertUnwindSafe(|| decode_stream.step(token)));
|
||||
assert!(maybe_panic.is_ok());
|
||||
let result = maybe_panic.unwrap();
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), expected_outputs[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
78
tokenizers/tests/stream.rs
Normal file
78
tokenizers/tests/stream.rs
Normal file
@ -0,0 +1,78 @@
|
||||
use tokenizers::{
|
||||
normalizers,
|
||||
pre_tokenizers::split::{Split, SplitPattern},
|
||||
AddedToken, NormalizerWrapper, PreTokenizerWrapper, SplitDelimiterBehavior, Tokenizer,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_decoding_with_added_bpe() {
|
||||
let mut tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
|
||||
tokenizer.with_normalizer(Some(NormalizerWrapper::from(normalizers::ByteLevel::new())));
|
||||
tokenizer.with_pre_tokenizer(Some(PreTokenizerWrapper::Split(
|
||||
Split::new(
|
||||
SplitPattern::Regex(r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".into()),
|
||||
SplitDelimiterBehavior::Isolated,
|
||||
false,
|
||||
)
|
||||
.unwrap(),
|
||||
)));
|
||||
tokenizer.add_tokens(&[AddedToken::from("嗎", false).normalized(false)]);
|
||||
let encoded = tokenizer
|
||||
.encode("Hey! how is this token: 嗎", false)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
encoded.get_ids(),
|
||||
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128256]
|
||||
);
|
||||
assert_eq!(
|
||||
encoded.get_tokens(),
|
||||
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "嗎"]
|
||||
);
|
||||
|
||||
let decoded = tokenizer.decode(encoded.get_ids(), false);
|
||||
assert_eq!(decoded.unwrap(), "Hey! how is this token: 嗎");
|
||||
|
||||
tokenizer.add_tokens(&[AddedToken::from("д", false).normalized(true)]);
|
||||
let encoded = tokenizer
|
||||
.encode("Hey! how is this token: д", false)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
encoded.get_ids(),
|
||||
[19182, 0, 1268, 602, 82, 62428, 82, 4037, 25, 220, 128257]
|
||||
);
|
||||
assert_eq!(
|
||||
encoded.get_tokens(),
|
||||
["Hey", "!", "Ġhow", "Ġi", "s", "Ġthi", "s", "Ġtoken", ":", "Ġ", "д"]
|
||||
);
|
||||
let decoded = tokenizer.decode(encoded.get_ids(), false);
|
||||
assert_eq!(decoded.unwrap(), "Hey! how is this token: д")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_stream_step_no_panic() {
|
||||
let tokenizer = Tokenizer::from_file("data/llama-3-tokenizer.json").unwrap();
|
||||
|
||||
// "A B C D E F G H I J"
|
||||
let mut decode_stream = tokenizer.decode_stream(false);
|
||||
assert_eq!(decode_stream.step(32).unwrap(), Some("A".to_string()));
|
||||
assert_eq!(decode_stream.step(426).unwrap(), Some(" B".to_string()));
|
||||
assert_eq!(decode_stream.step(356).unwrap(), Some(" C".to_string()));
|
||||
assert_eq!(decode_stream.step(423).unwrap(), Some(" D".to_string()));
|
||||
assert_eq!(decode_stream.step(469).unwrap(), Some(" E".to_string()));
|
||||
assert_eq!(decode_stream.step(435).unwrap(), Some(" F".to_string()));
|
||||
assert_eq!(decode_stream.step(480).unwrap(), Some(" G".to_string()));
|
||||
assert_eq!(decode_stream.step(473).unwrap(), Some(" H".to_string()));
|
||||
assert_eq!(decode_stream.step(358).unwrap(), Some(" I".to_string()));
|
||||
assert_eq!(decode_stream.step(622).unwrap(), Some(" J".to_string()));
|
||||
// for (i, &token) in output_tokens.iter().enumerate() {}
|
||||
|
||||
// "삥뽕빵" (Korean words composed of 2-3 tokens: [80690, 98], [167, 121, 243], and [102457, 113])
|
||||
let mut decode_stream = tokenizer.decode_stream(false);
|
||||
assert_eq!(decode_stream.step(80690).unwrap(), None);
|
||||
assert_eq!(decode_stream.step(98).unwrap(), Some("삥".to_string()));
|
||||
assert_eq!(decode_stream.step(167).unwrap(), None);
|
||||
assert_eq!(decode_stream.step(121).unwrap(), None);
|
||||
assert_eq!(decode_stream.step(243).unwrap(), Some("뽕".to_string()));
|
||||
assert_eq!(decode_stream.step(102457).unwrap(), None);
|
||||
assert_eq!(decode_stream.step(113).unwrap(), Some("빵".to_string()));
|
||||
}
|
Reference in New Issue
Block a user