Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-23 00:35:35 +00:00
Aho corasick version for many added tokens. (#871)

* Aho corasick version.
* Remove test file.
* Compile on `stable`.
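The gist of the change, as the diffs below show: instead of compiling one regex pattern per added token into a `RegexSet`, all added-token strings are fed into a single Aho-Corasick automaton built with `MatchKind::LeftmostLongest`, so a sentence is scanned once no matter how many tokens were added. A minimal, standalone sketch of that idea (hypothetical token strings, aho-corasick 0.7 as pinned by this change; not part of the diff):

```rust
use aho_corasick::{AhoCorasickBuilder, MatchKind};

fn main() {
    // Hypothetical added tokens; in the library they come from the AddedVocabulary.
    let added_tokens = ["[CLS]", "[SEP]", "[SPECIAL_42]"];

    // One automaton for all tokens, rebuilt only when the vocabulary changes.
    let trie = AhoCorasickBuilder::new()
        .match_kind(MatchKind::LeftmostLongest)
        .build(&added_tokens);

    // A single pass over the sentence finds every added token.
    let sentence = "[CLS] hello [SPECIAL_42] world [SEP]";
    for mat in trie.find_iter(sentence) {
        println!("pattern {} at {}..{}", mat.pattern(), mat.start(), mat.end());
    }
}
```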
bindings/node/native/Cargo.lock (generated, 1 addition)

@@ -1670,6 +1670,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 name = "tokenizers"
 version = "0.11.0"
 dependencies = [
+ "aho-corasick",
  "cached-path",
  "clap",
  "derive_builder",
bindings/python/Cargo.lock (generated, 1 addition)

@@ -1745,6 +1745,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 name = "tokenizers"
 version = "0.11.0"
 dependencies = [
+ "aho-corasick",
  "cached-path",
  "clap",
  "derive_builder",
tokenizers/Cargo.toml (1 addition)

@@ -56,6 +56,7 @@ spm_precompiled = "0.1"
 dirs = "3.0"
 reqwest = { version = "0.11", optional = true }
 cached-path = { version = "0.5", optional = true }
+aho-corasick = "0.7"
 
 [features]
 default = ["progressbar", "http"]
tokenizers/examples/serialization.rs (new file, 16 additions)

@@ -0,0 +1,16 @@
+use tokenizers::models::wordpiece::WordPiece;
+use tokenizers::{AddedToken, Tokenizer};
+
+fn main() {
+    let mut tokenizer = Tokenizer::new(WordPiece::default());
+
+    let tokens: Vec<_> = (0..4000)
+        .map(|i| AddedToken::from(format!("[SPECIAL_{}]", i), false))
+        .collect();
+    tokenizer.add_tokens(&tokens);
+    tokenizer.save("_tok.json", false).unwrap();
+    let start = std::time::Instant::now();
+    let _tok = Tokenizer::from_file("_tok.json").unwrap();
+    println!("Took {:?}", start.elapsed());
+    std::fs::remove_file("_tok.json").unwrap();
+}
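The example doubles as a micro-benchmark for the change: it saves a tokenizer with 4000 added tokens and times how long deserializing it takes. Assuming the usual cargo example discovery, it can be run from the `tokenizers` crate directory with `cargo run --release --example serialization`.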
tokenizers/src/tokenizer/added_vocabulary.rs

@@ -1,6 +1,7 @@
 use super::{
     normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
 };
+use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
 use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
 use std::collections::{HashMap, HashSet};
 
@@ -61,52 +62,6 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
-    /// Retrive the pattern built for this token, according to all the specified parameters.
-    pub fn get_pattern<N: Normalizer>(&self, normalizer: Option<&N>) -> String {
-        let mut r = if self.single_word {
-            let first_b = self
-                .content
-                .chars()
-                .next()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            let last_b = self
-                .content
-                .chars()
-                .last()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-
-            // Normalize the content
-            let mut content = NormalizedString::from(self.content.as_ref());
-            normalizer.map(|n| n.normalize(&mut content));
-            format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
-        } else {
-            regex::escape(&self.content)
-        };
-
-        if self.lstrip && self.rstrip {
-            r = format!(r"(\s)?{}(\s)?", r);
-        } else if self.lstrip {
-            r = format!(r"(\s)?{}", r);
-        } else if self.rstrip {
-            r = format!(r"{}(\s)?", r);
-        }
-
-        r
-    }
 }
 impl Default for AddedToken {
     fn default() -> Self {
@@ -133,7 +88,7 @@ impl std::cmp::PartialEq for AddedToken {
 }
 impl std::cmp::Eq for AddedToken {}
 
-type MatchingSet = (regex::RegexSet, Vec<u32>);
+type MatchingSet = (AhoCorasick, Vec<u32>);
 
 ///
 /// A vocabulary built on top of the Model
@@ -169,21 +124,27 @@ pub(super) struct AddedVocabulary {
     special_tokens_set: HashSet<String>,
 
     /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
-    split_re: MatchingSet,
+    split_trie: MatchingSet,
     /// A RegexSet containing all the normalized patterns used to split on AddedTokens
-    split_normalized_re: MatchingSet,
+    split_normalized_trie: MatchingSet,
 }
 
 impl AddedVocabulary {
     pub fn new() -> Self {
+        let trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build::<_, &&[u8]>(&[]);
+        let normalized_trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build::<_, &&[u8]>(&[]);
         Self {
             added_tokens_map: HashMap::new(),
             added_tokens_map_r: HashMap::new(),
             added_tokens: vec![],
             special_tokens: vec![],
             special_tokens_set: HashSet::new(),
-            split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
-            split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
+            split_trie: (trie, vec![]),
+            split_normalized_trie: (normalized_trie, vec![]),
         }
     }
 
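Note the `build::<_, &&[u8]>(&[])` turbofish in the empty constructors: with no patterns to infer from, the pattern type has to be spelled out explicitly, which is presumably what the "Compile on `stable`" follow-up in the commit message refers to.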
@@ -296,16 +257,26 @@ impl AddedVocabulary {
             .partition(|(token, _)| token.normalized);
 
         let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
-        self.split_re = (
-            regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
-            ids,
-        );
+        let trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build(tokens.iter().map(|token| &token.content));
+        self.split_trie = (trie, ids);
 
-        let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
-        self.split_normalized_re = (
-            regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
-            ids,
-        );
+        let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
+        let patterns: Vec<_> = ntokens
+            .iter()
+            .map(|token| {
+                let mut content = NormalizedString::from(token.content.as_ref());
+                if let Some(n) = normalizer {
+                    n.normalize(&mut content).unwrap();
+                }
+                content
+            })
+            .collect();
+        let normalized_trie = AhoCorasickBuilder::new()
+            .match_kind(MatchKind::LeftmostLongest)
+            .build(patterns.iter().map(|content| content.get()));
+        self.split_normalized_trie = (normalized_trie, nids);
     }
 
     /// Find any AddedToken in the given sentence, using the provided MatchingSet.
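A design note on this hunk: the removed `get_pattern` used to bake normalization and whitespace handling into an escaped regex per token, whereas now the normalizer is applied once to each token's content and the resulting strings become the automaton's patterns directly; the `single_word`/`lstrip`/`rstrip` options are instead handled at match time (next hunk).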
@@ -321,81 +292,49 @@ impl AddedVocabulary {
             return vec![(None, (0, 0))];
         }
 
-        let mut matches = split_re
-            .0
-            .matches(sentence)
-            .into_iter()
-            .flat_map(|idx| {
-                regex::Regex::new(&split_re.0.patterns()[idx])
-                    .unwrap()
-                    .find_iter(sentence)
-                    .map(|m| (idx, (m.start(), m.end())))
-                    .collect::<Vec<_>>()
-            })
-            .collect::<Vec<_>>();
-
-        // We sort all the matches by their start and then by their pattern id
-        matches.sort_by(
-            |(idxa, (sa, _)), (idxb, (sb, _))| {
-                if sa != sb {
-                    sa.cmp(sb)
-                } else {
-                    idxa.cmp(idxb)
-                }
-            },
-        );
-
-        // Select the matches (if some are overlapping) we want to keep
-        let mut i = 0;
-        let mut current_offset = 0;
-        let mut splits = Vec::with_capacity(matches.len());
-        while i < matches.len() {
-            let (idx, (start, end)) = matches[i];
-
-            // current match is before the currentt offset, let's skip it
-            if start < current_offset {
-                i += 1;
-                continue;
-            }
-
-            // Find out if we have overlapping neighbors. If so, we keep the one with the lowest
-            // idx, and apply it, then continue. All others will be skipped since `current_offset`
-            // will have been increased
-            if i + 1 < matches.len() {
-                if let Some((idx, (s, e))) = matches[i..]
-                    .iter()
-                    .take_while(|(_, (s, e))| *s < end && start < *e)
-                    .min() // Order on idx first
-                    .copied()
-                {
-                    splits.push((idx, (s, e)));
-                    current_offset = e;
-                    i += 1;
-                    continue;
-                }
-            }
-
-            // We didn't find overlapping neighbors, apply ourself
-            splits.push((idx, (start, end)));
-            current_offset = end;
-            i += 1;
-        }
-
-        // We also insert the splits that are inbetween the added tokens, to split the entire string
         let mut start_offset = 0;
-        let mut splits = splits
-            .into_iter()
-            .flat_map(|(idx, (start, end))| {
-                let mut splits = vec![];
-                if start_offset < start {
-                    splits.push((None, (start_offset, start)));
-                }
-                splits.push((Some(split_re.1[idx] as u32), (start, end)));
-                start_offset = end;
-                splits
-            })
-            .collect::<Vec<_>>();
+        let mut splits = vec![];
+
+        for mat in split_re.0.find_iter(sentence) {
+            let mut start = mat.start();
+            let mut stop = mat.end();
+            let aho_id = mat.pattern();
+            let id = split_re.1[aho_id];
+            let added_token = &self.added_tokens_map_r.get(&id).unwrap();
+            if added_token.single_word {
+                let start_space = start == 0 || sentence.as_bytes().get(start - 1) == Some(&b' ');
+                let stop_space =
+                    stop == sentence.len() || sentence.as_bytes().get(stop + 1) == Some(&b' ');
+
+                if !stop_space || !start_space {
+                    // Discard not single word
+                    continue;
+                }
+            }
+            if added_token.lstrip {
+                while start > 0 {
+                    if let Some(b' ') = sentence.as_bytes().get(start - 1) {
+                        start -= 1;
+                    } else {
+                        break;
+                    }
+                }
+            }
+            if added_token.rstrip {
+                while stop < sentence.len() {
+                    if let Some(b' ') = sentence.as_bytes().get(stop) {
+                        stop += 1;
+                    } else {
+                        break;
+                    }
+                }
+            }
+            if start_offset < start {
+                splits.push((None, (start_offset, start)));
+            }
+            splits.push((Some(id), (start, stop)));
+            start_offset = stop;
+        }
 
         let total_byte_len = sentence.len();
         if start_offset != total_byte_len {
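The `MatchingSet` pairs the automaton with a `Vec<u32>` of token ids in pattern order, so `mat.pattern()` indexes straight into it, and the gaps between matches become `(None, span)` splits. A standalone sketch of that bookkeeping with hypothetical tokens and ids (the real code additionally adjusts spans for `single_word`/`lstrip`/`rstrip`, as in the hunk above):

```rust
use aho_corasick::{AhoCorasickBuilder, MatchKind};

fn main() {
    // Hypothetical added tokens and ids, standing in for a MatchingSet.
    let tokens = ["[CLS]", "[SEP]"];
    let ids: Vec<u32> = vec![101, 102];
    let trie = AhoCorasickBuilder::new()
        .match_kind(MatchKind::LeftmostLongest)
        .build(&tokens);

    let sentence = "[CLS] hello [SEP]";
    let mut splits: Vec<(Option<u32>, (usize, usize))> = vec![];
    let mut start_offset = 0;
    for mat in trie.find_iter(sentence) {
        // Text between two added tokens becomes a (None, span) split...
        if start_offset < mat.start() {
            splits.push((None, (start_offset, mat.start())));
        }
        // ...and a matched token maps back to its id through the parallel Vec.
        splits.push((Some(ids[mat.pattern()]), (mat.start(), mat.end())));
        start_offset = mat.end();
    }
    if start_offset < sentence.len() {
        splits.push((None, (start_offset, sentence.len())));
    }
    println!("{:?}", splits);
    // -> [(Some(101), (0, 5)), (None, (5, 12)), (Some(102), (12, 17))]
}
```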
@@ -445,14 +384,14 @@ impl AddedVocabulary {
 
         // 1. We extract all the non-normalized tokens from the non-normalized string
         pretokenized
-            .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_re)))
+            .split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
             .expect("AddedVocabulary bad split");
 
         // 2. Then extract the normalized tokens from the normalized pieces of the string
         pretokenized
             .split(|_, mut sequence| {
                 normalizer.map(|n| n.normalize(&mut sequence));
-                Ok(self.split_with_indices(sequence, &self.split_normalized_re))
+                Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
             })
             .expect("AddedVocabulary bad split");
 
@@ -756,7 +695,7 @@ mod tests {
     #[test]
     fn empty_matches() {
         let vocab = AddedVocabulary::new();
-        let matches = vocab.find_matches("", &vocab.split_re);
+        let matches = vocab.find_matches("", &vocab.split_trie);
         assert_eq!(matches, vec![(None, (0, 0))]);
     }
 }
tokenizers/tests/added_tokens.rs

@@ -116,5 +116,12 @@ fn overlapping_tokens() {
 
     let output = tokenizer.encode(input, false).unwrap();
 
-    assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġda", "nci", "ng"]);
+    // Breaking change but following `transformers` breaking change.
+    // This behavior is deemed not used in practice:
+    // https://github.com/huggingface/transformers/pull/13220
+    // Order does NOT matter. (We could make it work again but the trie
+    // would need to keep insertion order too)
+    //
+    // assert_eq!(output.get_tokens(), &["I", "Ġlike", "Ġda", "nci", "ng"]);
+    assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġ", "danc", "ing"]);
 }
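The comment in this test is the crux of the breaking change: with `MatchKind::LeftmostLongest`, overlapping added tokens are resolved by position and length alone, not by the order in which they were added. A small, hypothetical demonstration of that semantics (not taken from the test suite):

```rust
use aho_corasick::{AhoCorasickBuilder, MatchKind};

fn main() {
    let haystack = "I like dancing";
    // Hypothetical overlapping "added tokens", tried in two insertion orders.
    for patterns in [vec!["da", "danc", "nci"], vec!["nci", "danc", "da"]] {
        let trie = AhoCorasickBuilder::new()
            .match_kind(MatchKind::LeftmostLongest)
            .build(&patterns);
        let found: Vec<&str> = trie
            .find_iter(haystack)
            .map(|m| &haystack[m.start()..m.end()])
            .collect();
        // Both orders print ["danc"]: the leftmost-longest match wins and
        // overlapping shorter candidates are dropped, regardless of order.
        println!("{:?}", found);
    }
}
```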