diff --git a/bindings/node/native/Cargo.lock b/bindings/node/native/Cargo.lock
index aecb0982..f930d92c 100644
--- a/bindings/node/native/Cargo.lock
+++ b/bindings/node/native/Cargo.lock
@@ -1670,6 +1670,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 name = "tokenizers"
 version = "0.11.0"
 dependencies = [
+ "aho-corasick",
  "cached-path",
  "clap",
  "derive_builder",
diff --git a/bindings/python/Cargo.lock b/bindings/python/Cargo.lock
index 4f7aee56..07217488 100644
--- a/bindings/python/Cargo.lock
+++ b/bindings/python/Cargo.lock
@@ -1745,6 +1745,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 name = "tokenizers"
 version = "0.11.0"
 dependencies = [
+ "aho-corasick",
  "cached-path",
  "clap",
  "derive_builder",
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 69bcf7d9..c3d8a78f 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -56,6 +56,7 @@ spm_precompiled = "0.1"
 dirs = "3.0"
 reqwest = { version = "0.11", optional = true }
 cached-path = { version = "0.5", optional = true }
+aho-corasick = "0.7"
 
 [features]
 default = ["progressbar", "http"]
diff --git a/tokenizers/examples/serialization.rs b/tokenizers/examples/serialization.rs
new file mode 100644
index 00000000..6d9a5f42
--- /dev/null
+++ b/tokenizers/examples/serialization.rs
@@ -0,0 +1,16 @@
+use tokenizers::models::wordpiece::WordPiece;
+use tokenizers::{AddedToken, Tokenizer};
+
+fn main() {
+    let mut tokenizer = Tokenizer::new(WordPiece::default());
+
+    let tokens: Vec<_> = (0..4000)
+        .map(|i| AddedToken::from(format!("[SPECIAL_{}]", i), false))
+        .collect();
+    tokenizer.add_tokens(&tokens);
+    tokenizer.save("_tok.json", false).unwrap();
+    let start = std::time::Instant::now();
+    let _tok = Tokenizer::from_file("_tok.json").unwrap();
+    println!("Took {:?}", start.elapsed());
+    std::fs::remove_file("_tok.json").unwrap();
+}
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index 444c9f61..72251f63 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -1,6 +1,7 @@
 use super::{
     normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
 };
+use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
 use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
 use std::collections::{HashMap, HashSet};
 
@@ -61,52 +62,6 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
-    /// Retrive the pattern built for this token, according to all the specified parameters.
-    pub fn get_pattern<N: Normalizer>(&self, normalizer: Option<&N>) -> String {
-        let mut r = if self.single_word {
-            let first_b = self
-                .content
-                .chars()
-                .next()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            let last_b = self
-                .content
-                .chars()
-                .last()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-
-            // Normalize the content
-            let mut content = NormalizedString::from(self.content.as_ref());
-            normalizer.map(|n| n.normalize(&mut content));
-            format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
-        } else {
-            regex::escape(&self.content)
-        };
-
-        if self.lstrip && self.rstrip {
-            r = format!(r"(\s)?{}(\s)?", r);
-        } else if self.lstrip {
-            r = format!(r"(\s)?{}", r);
-        } else if self.rstrip {
-            r = format!(r"{}(\s)?", r);
-        }
-
-        r
-    }
 }
 impl Default for AddedToken {
     fn default() -> Self {
@@ -133,7 +88,7 @@ impl std::cmp::PartialEq for AddedToken {
     }
 }
 impl std::cmp::Eq for AddedToken {}
 
-type MatchingSet = (regex::RegexSet, Vec<u32>);
+type MatchingSet = (AhoCorasick, Vec<u32>);
 
 ///
 /// A vocabulary built on top of the Model
@@ -169,21 +124,27 @@ pub(super) struct AddedVocabulary {
     special_tokens_set: HashSet<String>,
 
     /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
-    split_re: MatchingSet,
+    split_trie: MatchingSet,
     /// A RegexSet containing all the normalized patterns used to split on AddedTokens
-    split_normalized_re: MatchingSet,
+    split_normalized_trie: MatchingSet,
 }
 
 impl AddedVocabulary {
     pub fn new() -> Self {
+        let trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build::<_, &&[u8]>(&[]);
+        let normalized_trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build::<_, &&[u8]>(&[]);
         Self {
            added_tokens_map: HashMap::new(),
            added_tokens_map_r: HashMap::new(),
            added_tokens: vec![],
            special_tokens: vec![],
            special_tokens_set: HashSet::new(),
-            split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
-            split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
+            split_trie: (trie, vec![]),
+            split_normalized_trie: (normalized_trie, vec![]),
         }
     }
 
@@ -296,16 +257,26 @@ impl AddedVocabulary {
             .partition(|(token, _)| token.normalized);
 
         let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
-        self.split_re = (
-            regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
-            ids,
-        );
+        let trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build(tokens.iter().map(|token| &token.content));
+        self.split_trie = (trie, ids);
 
-        let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
-        self.split_normalized_re = (
-            regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
-            ids,
-        );
+        let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
+        let patterns: Vec<_> = ntokens
+            .iter()
+            .map(|token| {
+                let mut content = NormalizedString::from(token.content.as_ref());
+                if let Some(n) = normalizer {
+                    n.normalize(&mut content).unwrap();
+                }
+                content
+            })
+            .collect();
+        let normalized_trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build(patterns.iter().map(|content| content.get()));
+        self.split_normalized_trie = (normalized_trie, nids);
     }
 
     /// Find any AddedToken in the given sentence, using the provided MatchingSet.
@@ -321,81 +292,49 @@ impl AddedVocabulary {
             return vec![(None, (0, 0))];
         }
 
-        let mut matches = split_re
-            .0
-            .matches(sentence)
-            .into_iter()
-            .flat_map(|idx| {
-                regex::Regex::new(&split_re.0.patterns()[idx])
-                    .unwrap()
-                    .find_iter(sentence)
-                    .map(|m| (idx, (m.start(), m.end())))
-                    .collect::<Vec<_>>()
-            })
-            .collect::<Vec<_>>();
+        let mut start_offset = 0;
+        let mut splits = vec![];
 
-        // We sort all the matches by their start and then by their pattern id
-        matches.sort_by(
-            |(idxa, (sa, _)), (idxb, (sb, _))| {
-                if sa != sb {
-                    sa.cmp(sb)
-                } else {
-                    idxa.cmp(idxb)
-                }
-            },
-        );
+        for mat in split_re.0.find_iter(sentence) {
+            let mut start = mat.start();
+            let mut stop = mat.end();
+            let aho_id = mat.pattern();
+            let id = split_re.1[aho_id];
+            let added_token = &self.added_tokens_map_r.get(&id).unwrap();
+            if added_token.single_word {
+                let start_space = start == 0 || sentence.as_bytes().get(start - 1) == Some(&b' ');
+                let stop_space =
+                    stop == sentence.len() || sentence.as_bytes().get(stop + 1) == Some(&b' ');
 
-        // Select the matches (if some are overlapping) we want to keep
-        let mut i = 0;
-        let mut current_offset = 0;
-        let mut splits = Vec::with_capacity(matches.len());
-        while i < matches.len() {
-            let (idx, (start, end)) = matches[i];
-
-            // current match is before the currentt offset, let's skip it
-            if start < current_offset {
-                i += 1;
-                continue;
-            }
-
-            // Find out if we have overlapping neighbors. If so, we keep the one with the lowest
-            // idx, and apply it, then continue. All others will be skipped since `current_offset`
-            // will have been increased
-            if i + 1 < matches.len() {
-                if let Some((idx, (s, e))) = matches[i..]
-                    .iter()
-                    .take_while(|(_, (s, e))| *s < end && start < *e)
-                    .min() // Order on idx first
-                    .copied()
-                {
-                    splits.push((idx, (s, e)));
-                    current_offset = e;
-                    i += 1;
+                if !stop_space || !start_space {
+                    // Discard not single word
                     continue;
                 }
             }
-
-            // We didn't find overlapping neighbors, apply ourself
-            splits.push((idx, (start, end)));
-            current_offset = end;
-            i += 1;
-        }
-
-        // We also insert the splits that are inbetween the added tokens, to split the entire string
-        let mut start_offset = 0;
-        let mut splits = splits
-            .into_iter()
-            .flat_map(|(idx, (start, end))| {
-                let mut splits = vec![];
-                if start_offset < start {
-                    splits.push((None, (start_offset, start)));
+            if added_token.lstrip {
+                while start > 0 {
+                    if let Some(b' ') = sentence.as_bytes().get(start - 1) {
+                        start -= 1;
+                    } else {
+                        break;
+                    }
                 }
-                splits.push((Some(split_re.1[idx] as u32), (start, end)));
-                start_offset = end;
-
-                splits
-            })
-            .collect::<Vec<_>>();
+            }
+            if added_token.rstrip {
+                while stop < sentence.len() {
+                    if let Some(b' ') = sentence.as_bytes().get(stop) {
+                        stop += 1;
+                    } else {
+                        break;
+                    }
+                }
+            }
+            if start_offset < start {
+                splits.push((None, (start_offset, start)));
+            }
+            splits.push((Some(id), (start, stop)));
+            start_offset = stop;
+        }
 
         let total_byte_len = sentence.len();
         if start_offset != total_byte_len {
@@ -445,14 +384,14 @@ impl AddedVocabulary {
 
         // 1. We extract all the non-normalized tokens from the non-normalized string
         pretokenized
-            .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_re)))
+            .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
             .expect("AddedVocabulary bad split");
 
         // 2. Then extract the normalized tokens from the normalized pieces of the string
         pretokenized
             .split(|_, mut sequence| {
                 normalizer.map(|n| n.normalize(&mut sequence));
-                Ok(self.split_with_indices(sequence, &self.split_normalized_re))
+                Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
             })
             .expect("AddedVocabulary bad split");
 
@@ -756,7 +695,7 @@ mod tests {
     #[test]
     fn empty_matches() {
         let vocab = AddedVocabulary::new();
-        let matches = vocab.find_matches("", &vocab.split_re);
+        let matches = vocab.find_matches("", &vocab.split_trie);
         assert_eq!(matches, vec![(None, (0, 0))]);
     }
 }
diff --git a/tokenizers/tests/added_tokens.rs b/tokenizers/tests/added_tokens.rs
index 4d8235cc..6c146779 100644
--- a/tokenizers/tests/added_tokens.rs
+++ b/tokenizers/tests/added_tokens.rs
@@ -116,5 +116,12 @@ fn overlapping_tokens() {
 
     let output = tokenizer.encode(input, false).unwrap();
 
-    assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġda", "nci", "ng"]);
+    // Breaking change but following `transformers` breaking change.
+    // This behavior is deemed not used in practice:
+    // https://github.com/huggingface/transformers/pull/13220
+    // Order does NOT matter. (We could make it work again but the trie
+    // would need to keep insertion order too)
+    //
+    // assert_eq!(output.get_tokens(), &["I", "Ġlike", "Ġda", "nci", "ng"]);
+    assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġ", "danc", "ing"]);
 }
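Side note (not part of the diff): a minimal, self-contained sketch of the leftmost-longest matching that the new `split_trie` relies on, using the same `aho-corasick` 0.7 API as the patch. The patterns and haystack below are made up for illustration; the point is that overlapping candidates are resolved by match length rather than by the order in which tokens were added, which is why the `overlapping_tokens` expectation above changed.

```rust
use aho_corasick::{AhoCorasickBuilder, MatchKind};

fn main() {
    // Hypothetical overlapping "added tokens". In the real code these come from
    // AddedVocabulary, and mat.pattern() is mapped to a token id via split_trie.1.
    let patterns = &["danc", "dancing", "ing"];
    let trie = AhoCorasickBuilder::new()
        .match_kind(MatchKind::LeftmostLongest)
        .build(patterns);

    let haystack = "I like dancing";
    for mat in trie.find_iter(haystack) {
        // With LeftmostLongest, "dancing" wins over "danc" and swallows "ing",
        // regardless of insertion order.
        println!(
            "pattern #{} -> {:?} at {}..{}",
            mat.pattern(),
            &haystack[mat.start()..mat.end()],
            mat.start(),
            mat.end()
        );
    }
}
```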