* clippy

* fmtr

* rutc?

* fix onig issue

* up

* decode stream default

* jump a release for cargo audit ...

* more cliippy stuff

* clippy?

* proper style

* fmt

Arthur authored on 2025-05-27 11:30:32 +02:00, committed by GitHub
parent 23e7e42adf
commit 01f8bc834c
18 changed files with 46 additions and 59 deletions


@@ -10,7 +10,7 @@ jobs:
   build:
     runs-on: ${{ matrix.os }}
     env:
-      MACOSX_DEPLOYMENT_TARGET: 10.11
+      MACOSX_DEPLOYMENT_TARGET: 10.12
     strategy:
       matrix:
         os: [ubuntu-latest, windows-latest, macOS-latest]


@@ -14,8 +14,8 @@ serde = { version = "1.0", features = ["rc", "derive"] }
 serde_json = "1.0"
 libc = "0.2"
 env_logger = "0.11"
-pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] }
-numpy = "0.23"
+pyo3 = { version = "0.24", features = ["abi3", "abi3-py39", "py-clone"] }
+numpy = "0.24"
 ndarray = "0.16"
 itertools = "0.12"
@@ -24,7 +24,7 @@ path = "../../tokenizers"
 [dev-dependencies]
 tempfile = "3.10"
-pyo3 = { version = "0.23", features = ["auto-initialize"] }
+pyo3 = { version = "0.24", features = ["auto-initialize"] }
 [features]
 default = ["pyo3/extension-module"]


@@ -33,7 +33,7 @@ class BPEDecoder(Decoder):
     Args:
         suffix (:obj:`str`, `optional`, defaults to :obj:`</w>`):
-            The suffix that was used to caracterize an end-of-word. This suffix will
+            The suffix that was used to characterize an end-of-word. This suffix will
             be replaced by whitespaces during the decoding
     """
     def __init__(self, suffix="</w>"):


@@ -7,7 +7,7 @@ use tokenizers::Tokenizer;
 pub fn llama3(c: &mut Criterion) {
     let data = std::fs::read_to_string("data/big.txt").unwrap();
     let mut group = c.benchmark_group("llama3-encode");
-    group.throughput(Throughput::Bytes(data.bytes().len() as u64));
+    group.throughput(Throughput::Bytes(data.len() as u64));
     group.bench_function("llama3-offsets", |b| {
         let tokenizer =
             Tokenizer::from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", None).unwrap();
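Note: `str::bytes()` is an iterator over the UTF-8 bytes, and its `len()` always equals the string's byte length, so the intermediate iterator adds nothing here. A minimal standalone sketch of the equivalence (illustrative only, not repository code):

    fn main() {
        let data = String::from("héllo"); // 'é' is 2 bytes, 6 bytes total
        // `bytes()` iterates over the UTF-8 bytes; its length is the byte length.
        assert_eq!(data.bytes().len(), data.len());
        assert_eq!(data.len() as u64, 6);
    }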


@@ -28,11 +28,7 @@ impl Decoder for ByteFallback {
         for token in tokens {
             let bytes = if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
-                if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) {
-                    Some(byte)
-                } else {
-                    None
-                }
+                u8::from_str_radix(&token[3..5], 16).ok()
             } else {
                 None
             };
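Note: `Result::ok()` maps `Ok(v)` to `Some(v)` and `Err(_)` to `None`, which is exactly what the removed `if let`/`else` spelled out by hand. A small sketch of the same parsing logic (hypothetical helper, not the crate's API):

    fn parse_hex_byte(token: &str) -> Option<u8> {
        // "<0x41>"[3..5] == "41"; invalid hex simply becomes None.
        u8::from_str_radix(&token[3..5], 16).ok()
    }

    fn main() {
        assert_eq!(parse_hex_byte("<0x41>"), Some(0x41));
        assert_eq!(parse_hex_byte("<0xZZ>"), None);
    }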


@@ -35,7 +35,7 @@ impl Serialize for OrderedVocabIter<'_> {
     {
         // There could be holes so max + 1 is more correct than vocab_r.len()
         let mut holes = vec![];
-        let result = if let Some(max) = self.vocab_r.iter().map(|(key, _)| key).max() {
+        let result = if let Some(max) = self.vocab_r.keys().max() {
             let iter = (0..*max + 1).filter_map(|i| {
                 if let Some(token) = self.vocab_r.get(&i) {
                     Some((token, i))
@@ -50,7 +50,7 @@ impl Serialize for OrderedVocabIter<'_> {
         };
         if !holes.is_empty() {
-            warn!("The OrderedVocab you are attempting to save contains holes for indices {:?}, your vocabulary could be corrupted !", holes);
+            warn!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
             println!("The OrderedVocab you are attempting to save contains holes for indices {holes:?}, your vocabulary could be corrupted !");
         }
         result
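Note: for a `HashMap`, `.keys()` is the direct way to iterate the keys; `.iter().map(|(key, _)| key)` does the same thing with extra noise, and clippy flags the longer form. A standalone sketch of the equivalence, using an assumed `u32 -> String` reverse vocabulary (illustrative only):

    use std::collections::HashMap;

    fn main() {
        // A reverse vocab with a hole at index 1.
        let vocab_r: HashMap<u32, String> =
            HashMap::from([(0, "a".into()), (2, "b".into())]);
        let max_a = vocab_r.iter().map(|(key, _)| key).max();
        let max_b = vocab_r.keys().max();
        assert_eq!(max_a, max_b); // Some(&2)
    }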


@@ -313,7 +313,7 @@ impl Unigram {
                 && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)?
             {
                 token.push(
-                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
+                    String::from_utf8((sentence.as_bytes()[starts_at..ends_at]).to_vec()).unwrap(),
                 );
             } else {
                 if !token.is_empty() {
@@ -322,7 +322,7 @@ impl Unigram {
                     token = vec![];
                 }
                 results.push(
-                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
+                    String::from_utf8((sentence.as_bytes()[starts_at..ends_at]).to_vec()).unwrap(),
                 );
             }
             ends_at = starts_at;
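Note: `starts_at`/`ends_at` are byte offsets, so the new form slices the raw bytes directly. The old form sliced the `&str` first, which performs a UTF-8 char-boundary check (and panics on a non-boundary index) before handing out the same bytes. The same rewrite appears in the two ByteLevel hunks below. A minimal sketch of the difference (illustrative only, not repository code):

    fn main() {
        let sentence = "día"; // byte widths: d = 1, í = 2, a = 1
        // Slicing the byte slice never looks at char boundaries:
        let bytes = sentence.as_bytes()[0..2].to_vec();
        assert_eq!(bytes.len(), 2);
        // Slicing the &str first would panic here, because byte index 2
        // falls inside the 2-byte 'í':
        // let _ = &sentence[0..2];
    }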


@@ -35,7 +35,7 @@ impl Normalizer for ByteLevel {
         let mut i = 0;
         for cur_char in s.chars() {
             let size = cur_char.len_utf8();
-            let bytes = s[i..i + size].as_bytes();
+            let bytes = &s.as_bytes()[i..i + size];
             i += size;
             transformations.extend(
                 bytes


@@ -135,7 +135,7 @@ impl PreTokenizer for ByteLevel {
         let mut i = 0;
         for cur_char in s.chars() {
             let size = cur_char.len_utf8();
-            let bytes = s[i..i + size].as_bytes();
+            let bytes = &s.as_bytes()[i..i + size];
             i += size;
             transformations.extend(
                 bytes


@@ -65,9 +65,9 @@ impl PostProcessor for BertProcessing {
             let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
             let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
             let tokens = [
-                &[self.cls.0.clone()],
+                std::slice::from_ref(&self.cls.0),
                 encoding.get_tokens(),
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
             ]
             .concat();
             let words = [&[None], encoding.get_word_ids(), &[None]].concat();
@@ -95,9 +95,9 @@ impl PostProcessor for BertProcessing {
                 [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
             let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
             let tokens = [
-                &[self.cls.0.clone()],
+                std::slice::from_ref(&self.cls.0),
                 encoding.get_tokens(),
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
             ]
             .concat();
             let words = [&[None], encoding.get_word_ids(), &[None]].concat();
@@ -130,7 +130,8 @@ impl PostProcessor for BertProcessing {
             } else {
                 let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
                 let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
-                let pair_tokens = [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
+                let pair_tokens =
+                    [encoding.get_tokens(), std::slice::from_ref(&self.sep.0)].concat();
                 let pair_words = [encoding.get_word_ids(), &[None]].concat();
                 let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
                 let pair_special_tokens =
@@ -155,7 +156,8 @@ impl PostProcessor for BertProcessing {
                     let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
                     let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
                     let pair_tokens =
-                        [encoding.get_tokens(), &[self.sep.0.clone()]].concat();
+                        [encoding.get_tokens(), std::slice::from_ref(&self.sep.0)]
+                            .concat();
                     let pair_words = [encoding.get_word_ids(), &[None]].concat();
                     let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
                     let pair_special_tokens =
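Note: `std::slice::from_ref(&x)` borrows `x` as a one-element `&[T]`, so the special token is no longer cloned into a temporary array just to be concatenated; `concat()` still clones each element exactly once while building the output `Vec`. The same substitution is applied throughout the Roberta processor below. A minimal sketch of the pattern (illustrative only, not repository code):

    fn main() {
        let cls = String::from("[CLS]");
        let tokens = vec![String::from("hello"), String::from("world")];

        // Old: clone `cls` just to build a temporary one-element array.
        let with_clone: Vec<String> = [&[cls.clone()][..], &tokens[..]].concat();
        // New: borrow `cls` as a &[String] of length one.
        let with_ref: Vec<String> = [std::slice::from_ref(&cls), &tokens[..]].concat();

        assert_eq!(with_clone, with_ref);
    }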


@@ -95,9 +95,9 @@ impl PostProcessor for RobertaProcessing {
             let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
             let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
             let tokens = [
-                &[self.cls.0.clone()],
+                std::slice::from_ref(&self.cls.0),
                 encoding.get_tokens(),
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
             ]
             .concat();
             let words = [&[None], encoding.get_word_ids(), &[None]].concat();
@@ -125,9 +125,9 @@ impl PostProcessor for RobertaProcessing {
                 [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
             let type_ids = vec![0; encoding.get_ids().len() + 2];
             let tokens = [
-                &[self.cls.0.clone()],
+                std::slice::from_ref(&self.cls.0),
                 encoding.get_tokens(),
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
             ]
             .concat();
             let words = [&[None], encoding.get_word_ids(), &[None]].concat();
@@ -161,9 +161,9 @@ impl PostProcessor for RobertaProcessing {
             let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
             let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
             let pair_tokens = [
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
                 encoding.get_tokens(),
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
             ]
             .concat();
             let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
@@ -191,9 +191,9 @@ impl PostProcessor for RobertaProcessing {
                 [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
             let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
             let pair_tokens = [
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
                 encoding.get_tokens(),
-                &[self.sep.0.clone()],
+                std::slice::from_ref(&self.sep.0),
             ]
             .concat();
             let pair_words =


@@ -565,16 +565,16 @@ impl TemplateProcessing {
                 let encoding = Encoding::new(
                     tok.ids.clone(),
-                    std::iter::repeat(*type_id).take(len).collect(),
+                    std::iter::repeat_n(*type_id, len).collect(),
                     tok.tokens.clone(),
                     // words
-                    std::iter::repeat(None).take(len).collect(),
+                    std::iter::repeat_n(None, len).collect(),
                     // offsets
-                    std::iter::repeat((0, 0)).take(len).collect(),
+                    std::iter::repeat_n((0, 0), len).collect(),
                     // special_tokens_mask
-                    std::iter::repeat(1).take(len).collect(),
+                    std::iter::repeat_n(1, len).collect(),
                     // attention_mask
-                    std::iter::repeat(1).take(len).collect(),
+                    std::iter::repeat_n(1, len).collect(),
                     // overflowing
                     vec![],
                     // sequence_range
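Note: `std::iter::repeat_n(value, n)` (a fairly recent addition to `std`) is the clippy-preferred spelling of `repeat(value).take(n)`; the results are identical, and `repeat_n` can yield the value itself on the last iteration instead of cloning it. The Encoding hunk further down makes the same substitution. A tiny equivalence check (illustrative only):

    fn main() {
        let len = 3;
        let a: Vec<u32> = std::iter::repeat(1).take(len).collect();
        let b: Vec<u32> = std::iter::repeat_n(1, len).collect();
        assert_eq!(a, b); // [1, 1, 1]
    }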


@@ -668,7 +668,7 @@ mod tests {
         // Also adds tokens already covered by the model
        let added_token = AddedToken::from("test", false);
        assert_eq!(
-           vocab.add_tokens(&[added_token.clone()], &model, normalizer),
+           vocab.add_tokens(std::slice::from_ref(&added_token), &model, normalizer),
            1
        );
        assert_eq!(vocab.len(), 3);


@@ -139,7 +139,7 @@ impl Encoding {
         for seq_id in 0..self.n_sequences() {
             let range = self.sequence_range(seq_id);
             let seq_len = range.len();
-            sequences.splice(range, std::iter::repeat(Some(seq_id)).take(seq_len));
+            sequences.splice(range, std::iter::repeat_n(Some(seq_id), seq_len));
         }
         sequences
     }


@@ -328,9 +328,7 @@ impl NormalizedString {
             },
         };
         trace!(
-            "===== transform_range call with {:?} (initial_offset: {}) =====",
-            n_range,
-            initial_offset
+            "===== transform_range call with {n_range:?} (initial_offset: {initial_offset}) ====="
         );
         // Retrieve the original characters that are being replaced. This let us
@@ -386,9 +384,7 @@ impl NormalizedString {
             let replaced_char_size_change = c.len_utf8() as isize - replaced_char_size as isize;
             if let Some(ref replaced_char) = replaced_char {
                 trace!(
-                    "Replacing char {:?} - with a change in size: {}",
-                    replaced_char,
-                    replaced_char_size_change
+                    "Replacing char {replaced_char:?} - with a change in size: {replaced_char_size_change}"
                 );
             }
@@ -401,12 +397,12 @@ impl NormalizedString {
             } else {
                 0
             };
-            trace!("Total bytes to remove: {}", total_bytes_to_remove);
+            trace!("Total bytes to remove: {total_bytes_to_remove}");
             // Keep track of the changes for next offsets
             offset += replaced_char_size as isize;
             offset += total_bytes_to_remove as isize;
-            trace!("New offset: {}", offset);
+            trace!("New offset: {offset}");
             trace!("New normalized alignment: {}x {:?}", c.len_utf8(), align);
             alignments.extend((0..c.len_utf8()).map(|_| align));
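Note: these hunks inline plain identifiers into the format strings (clippy's `uninlined_format_args`); the output is unchanged. Only whole identifiers can be captured this way, which is why the `warn!` hunk further down keeps positional arguments for the field accesses. A small sketch of the equivalence (illustrative only):

    fn main() {
        let offset = 42;
        let n_range = 0..5;
        let old = format!("range {:?} (offset: {})", n_range, offset);
        let new = format!("range {n_range:?} (offset: {offset})");
        assert_eq!(old, new); // "range 0..5 (offset: 42)"
    }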


@@ -159,9 +159,7 @@ where
             if rid != token.id {
                 warn!(
                     "Warning: Token '{}' was expected to have ID '{}' but was given ID '{}'",
-                    token.token.content,
-                    token.id,
-                    rid.to_string()
+                    token.token.content, token.id, rid
                 );
             }
         }


@@ -36,14 +36,13 @@ pub fn from_pretrained<S: AsRef<str>>(
     let valid_chars_stringified = valid_chars
         .iter()
         .fold(vec![], |mut buf, x| {
-            buf.push(format!("'{}'", x));
+            buf.push(format!("'{x}'"));
             buf
         })
         .join(", "); // "'/', '-', '_', '.'"
     if !valid {
         return Err(format!(
-            "Model \"{}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}",
-            identifier
+            "Model \"{identifier}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}"
         )
         .into());
     }
@@ -53,8 +52,7 @@ pub fn from_pretrained<S: AsRef<str>>(
     let valid_revision = revision.chars().all(is_valid_char);
     if !valid_revision {
         return Err(format!(
-            "Revision \"{}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}",
-            revision
+            "Revision \"{revision}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}"
         )
         .into());
     }


@@ -2,18 +2,15 @@ pub(crate) mod cache;
 #[cfg(feature = "http")]
 pub(crate) mod from_pretrained;
-#[cfg(feature = "fancy-regex")]
+#[cfg(all(feature = "fancy-regex", not(feature = "onig")))]
 mod fancy;
-#[cfg(feature = "fancy-regex")]
+#[cfg(all(feature = "fancy-regex", not(feature = "onig")))]
 pub use fancy::SysRegex;
 #[cfg(feature = "onig")]
 mod onig;
 #[cfg(feature = "onig")]
 pub use crate::utils::onig::SysRegex;
-#[cfg(all(feature = "onig", feature = "fancy-regex"))]
-compile_error!("Features `onig` and `fancy-regex` are mutually exclusive");
 #[cfg(not(any(feature = "onig", feature = "fancy-regex")))]
 compile_error!("One of the `onig`, or `fancy-regex` features must be enabled");
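Note: previously, enabling both `onig` and `fancy-regex` was a hard `compile_error!`. With this change the `fancy` backend is only compiled when `onig` is disabled, so `onig` quietly takes precedence when both features are on; since Cargo features are additive across the dependency graph, a build that ends up with both enabled no longer fails. The error for enabling neither backend is kept.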