mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Aho corasick version for many added tokens. (#871)
* Aho corasick version. * Remove test file. * Compile on `stable`.
This commit is contained in:
1
bindings/node/native/Cargo.lock
generated
1
bindings/node/native/Cargo.lock
generated
@ -1670,6 +1670,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
|
||||
name = "tokenizers"
|
||||
version = "0.11.0"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"cached-path",
|
||||
"clap",
|
||||
"derive_builder",
|
||||
|
1
bindings/python/Cargo.lock
generated
1
bindings/python/Cargo.lock
generated
@ -1745,6 +1745,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
|
||||
name = "tokenizers"
|
||||
version = "0.11.0"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"cached-path",
|
||||
"clap",
|
||||
"derive_builder",
|
||||
|
@ -56,6 +56,7 @@ spm_precompiled = "0.1"
|
||||
dirs = "3.0"
|
||||
reqwest = { version = "0.11", optional = true }
|
||||
cached-path = { version = "0.5", optional = true }
|
||||
aho-corasick = "0.7"
|
||||
|
||||
[features]
|
||||
default = ["progressbar", "http"]
|
||||
|
16
tokenizers/examples/serialization.rs
Normal file
16
tokenizers/examples/serialization.rs
Normal file
@ -0,0 +1,16 @@
|
||||
use tokenizers::models::wordpiece::WordPiece;
|
||||
use tokenizers::{AddedToken, Tokenizer};
|
||||
|
||||
fn main() {
|
||||
let mut tokenizer = Tokenizer::new(WordPiece::default());
|
||||
|
||||
let tokens: Vec<_> = (0..4000)
|
||||
.map(|i| AddedToken::from(format!("[SPECIAL_{}]", i), false))
|
||||
.collect();
|
||||
tokenizer.add_tokens(&tokens);
|
||||
tokenizer.save("_tok.json", false).unwrap();
|
||||
let start = std::time::Instant::now();
|
||||
let _tok = Tokenizer::from_file("_tok.json").unwrap();
|
||||
println!("Took {:?}", start.elapsed());
|
||||
std::fs::remove_file("_tok.json").unwrap();
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
use super::{
|
||||
normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
|
||||
};
|
||||
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
|
||||
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
@ -61,52 +62,6 @@ impl AddedToken {
|
||||
self.normalized = normalized;
|
||||
self
|
||||
}
|
||||
/// Retrive the pattern built for this token, according to all the specified parameters.
|
||||
pub fn get_pattern<N: Normalizer>(&self, normalizer: Option<&N>) -> String {
|
||||
let mut r = if self.single_word {
|
||||
let first_b = self
|
||||
.content
|
||||
.chars()
|
||||
.next()
|
||||
.map(|c| {
|
||||
if regex_syntax::is_word_character(c) {
|
||||
r"\b"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let last_b = self
|
||||
.content
|
||||
.chars()
|
||||
.last()
|
||||
.map(|c| {
|
||||
if regex_syntax::is_word_character(c) {
|
||||
r"\b"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Normalize the content
|
||||
let mut content = NormalizedString::from(self.content.as_ref());
|
||||
normalizer.map(|n| n.normalize(&mut content));
|
||||
format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
|
||||
} else {
|
||||
regex::escape(&self.content)
|
||||
};
|
||||
|
||||
if self.lstrip && self.rstrip {
|
||||
r = format!(r"(\s)?{}(\s)?", r);
|
||||
} else if self.lstrip {
|
||||
r = format!(r"(\s)?{}", r);
|
||||
} else if self.rstrip {
|
||||
r = format!(r"{}(\s)?", r);
|
||||
}
|
||||
|
||||
r
|
||||
}
|
||||
}
|
||||
impl Default for AddedToken {
|
||||
fn default() -> Self {
|
||||
@ -133,7 +88,7 @@ impl std::cmp::PartialEq for AddedToken {
|
||||
}
|
||||
impl std::cmp::Eq for AddedToken {}
|
||||
|
||||
type MatchingSet = (regex::RegexSet, Vec<u32>);
|
||||
type MatchingSet = (AhoCorasick, Vec<u32>);
|
||||
|
||||
///
|
||||
/// A vocabulary built on top of the Model
|
||||
@ -169,21 +124,27 @@ pub(super) struct AddedVocabulary {
|
||||
special_tokens_set: HashSet<String>,
|
||||
|
||||
/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
|
||||
split_re: MatchingSet,
|
||||
split_trie: MatchingSet,
|
||||
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
|
||||
split_normalized_re: MatchingSet,
|
||||
split_normalized_trie: MatchingSet,
|
||||
}
|
||||
|
||||
impl AddedVocabulary {
|
||||
pub fn new() -> Self {
|
||||
let trie = AhoCorasickBuilder::new()
|
||||
.match_kind(MatchKind::LeftmostLongest)
|
||||
.build::<_, &&[u8]>(&[]);
|
||||
let normalized_trie = AhoCorasickBuilder::new()
|
||||
.match_kind(MatchKind::LeftmostLongest)
|
||||
.build::<_, &&[u8]>(&[]);
|
||||
Self {
|
||||
added_tokens_map: HashMap::new(),
|
||||
added_tokens_map_r: HashMap::new(),
|
||||
added_tokens: vec![],
|
||||
special_tokens: vec![],
|
||||
special_tokens_set: HashSet::new(),
|
||||
split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
|
||||
split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
|
||||
split_trie: (trie, vec![]),
|
||||
split_normalized_trie: (normalized_trie, vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
@ -296,16 +257,26 @@ impl AddedVocabulary {
|
||||
.partition(|(token, _)| token.normalized);
|
||||
|
||||
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
|
||||
self.split_re = (
|
||||
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
|
||||
ids,
|
||||
);
|
||||
let trie = AhoCorasickBuilder::new()
|
||||
.match_kind(MatchKind::LeftmostLongest)
|
||||
.build(tokens.iter().map(|token| &token.content));
|
||||
self.split_trie = (trie, ids);
|
||||
|
||||
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
|
||||
self.split_normalized_re = (
|
||||
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
|
||||
ids,
|
||||
);
|
||||
let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
|
||||
let patterns: Vec<_> = ntokens
|
||||
.iter()
|
||||
.map(|token| {
|
||||
let mut content = NormalizedString::from(token.content.as_ref());
|
||||
if let Some(n) = normalizer {
|
||||
n.normalize(&mut content).unwrap();
|
||||
}
|
||||
content
|
||||
})
|
||||
.collect();
|
||||
let normalized_trie = AhoCorasickBuilder::new()
|
||||
.match_kind(MatchKind::LeftmostLongest)
|
||||
.build(patterns.iter().map(|content| content.get()));
|
||||
self.split_normalized_trie = (normalized_trie, nids);
|
||||
}
|
||||
|
||||
/// Find any AddedToken in the given sentence, using the provided MatchingSet.
|
||||
@ -321,81 +292,49 @@ impl AddedVocabulary {
|
||||
return vec![(None, (0, 0))];
|
||||
}
|
||||
|
||||
let mut matches = split_re
|
||||
.0
|
||||
.matches(sentence)
|
||||
.into_iter()
|
||||
.flat_map(|idx| {
|
||||
regex::Regex::new(&split_re.0.patterns()[idx])
|
||||
.unwrap()
|
||||
.find_iter(sentence)
|
||||
.map(|m| (idx, (m.start(), m.end())))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut start_offset = 0;
|
||||
let mut splits = vec![];
|
||||
|
||||
// We sort all the matches by their start and then by their pattern id
|
||||
matches.sort_by(
|
||||
|(idxa, (sa, _)), (idxb, (sb, _))| {
|
||||
if sa != sb {
|
||||
sa.cmp(sb)
|
||||
} else {
|
||||
idxa.cmp(idxb)
|
||||
}
|
||||
},
|
||||
);
|
||||
for mat in split_re.0.find_iter(sentence) {
|
||||
let mut start = mat.start();
|
||||
let mut stop = mat.end();
|
||||
let aho_id = mat.pattern();
|
||||
let id = split_re.1[aho_id];
|
||||
let added_token = &self.added_tokens_map_r.get(&id).unwrap();
|
||||
if added_token.single_word {
|
||||
let start_space = start == 0 || sentence.as_bytes().get(start - 1) == Some(&b' ');
|
||||
let stop_space =
|
||||
stop == sentence.len() || sentence.as_bytes().get(stop + 1) == Some(&b' ');
|
||||
|
||||
// Select the matches (if some are overlapping) we want to keep
|
||||
let mut i = 0;
|
||||
let mut current_offset = 0;
|
||||
let mut splits = Vec::with_capacity(matches.len());
|
||||
while i < matches.len() {
|
||||
let (idx, (start, end)) = matches[i];
|
||||
|
||||
// current match is before the currentt offset, let's skip it
|
||||
if start < current_offset {
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find out if we have overlapping neighbors. If so, we keep the one with the lowest
|
||||
// idx, and apply it, then continue. All others will be skipped since `current_offset`
|
||||
// will have been increased
|
||||
if i + 1 < matches.len() {
|
||||
if let Some((idx, (s, e))) = matches[i..]
|
||||
.iter()
|
||||
.take_while(|(_, (s, e))| *s < end && start < *e)
|
||||
.min() // Order on idx first
|
||||
.copied()
|
||||
{
|
||||
splits.push((idx, (s, e)));
|
||||
current_offset = e;
|
||||
i += 1;
|
||||
if !stop_space || !start_space {
|
||||
// Discard not single word
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// We didn't find overlapping neighbors, apply ourself
|
||||
splits.push((idx, (start, end)));
|
||||
current_offset = end;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// We also insert the splits that are inbetween the added tokens, to split the entire string
|
||||
let mut start_offset = 0;
|
||||
let mut splits = splits
|
||||
.into_iter()
|
||||
.flat_map(|(idx, (start, end))| {
|
||||
let mut splits = vec![];
|
||||
if start_offset < start {
|
||||
splits.push((None, (start_offset, start)));
|
||||
if added_token.lstrip {
|
||||
while start > 0 {
|
||||
if let Some(b' ') = sentence.as_bytes().get(start - 1) {
|
||||
start -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
splits.push((Some(split_re.1[idx] as u32), (start, end)));
|
||||
start_offset = end;
|
||||
|
||||
splits
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}
|
||||
if added_token.rstrip {
|
||||
while stop < sentence.len() {
|
||||
if let Some(b' ') = sentence.as_bytes().get(stop) {
|
||||
stop += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if start_offset < start {
|
||||
splits.push((None, (start_offset, start)));
|
||||
}
|
||||
splits.push((Some(id), (start, stop)));
|
||||
start_offset = stop;
|
||||
}
|
||||
|
||||
let total_byte_len = sentence.len();
|
||||
if start_offset != total_byte_len {
|
||||
@ -445,14 +384,14 @@ impl AddedVocabulary {
|
||||
|
||||
// 1. We extract all the non-normalized tokens from the non-normalized string
|
||||
pretokenized
|
||||
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_re)))
|
||||
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
|
||||
.expect("AddedVocabulary bad split");
|
||||
|
||||
// 2. Then extract the normalized tokens from the normalized pieces of the string
|
||||
pretokenized
|
||||
.split(|_, mut sequence| {
|
||||
normalizer.map(|n| n.normalize(&mut sequence));
|
||||
Ok(self.split_with_indices(sequence, &self.split_normalized_re))
|
||||
Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
|
||||
})
|
||||
.expect("AddedVocabulary bad split");
|
||||
|
||||
@ -756,7 +695,7 @@ mod tests {
|
||||
#[test]
|
||||
fn empty_matches() {
|
||||
let vocab = AddedVocabulary::new();
|
||||
let matches = vocab.find_matches("", &vocab.split_re);
|
||||
let matches = vocab.find_matches("", &vocab.split_trie);
|
||||
assert_eq!(matches, vec![(None, (0, 0))]);
|
||||
}
|
||||
}
|
||||
|
@ -116,5 +116,12 @@ fn overlapping_tokens() {
|
||||
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
|
||||
assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġda", "nci", "ng"]);
|
||||
// Breaking change but following `transformers` breaking change.
|
||||
// This behavior is deemed not used in practice:
|
||||
// https://github.com/huggingface/transformers/pull/13220
|
||||
// Order does NOT matter. (We could make it work again but the trie
|
||||
// would need to keep insertion order too)
|
||||
//
|
||||
// assert_eq!(output.get_tokens(), &["I", "Ġlike", "Ġda", "nci", "ng"]);
|
||||
assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġ", "danc", "ing"]);
|
||||
}
|
||||
|
Reference in New Issue
Block a user