Aho corasick version for many added tokens. (#871)

* Aho corasick version.

* Remove test file.

* Compile on `stable`.
Nicolas Patry
2022-01-06 16:04:51 +01:00
committed by GitHub
parent fb837b4adb
commit 076319d542
6 changed files with 98 additions and 133 deletions
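
For context: this patch replaces the `regex::RegexSet` previously used to locate added tokens with an `aho_corasick` automaton built with leftmost-longest match semantics, which copes much better with thousands of literal patterns. Below is a minimal standalone sketch (not part of the commit; the token contents and ids are made up) of the matching primitive that the new `MatchingSet = (AhoCorasick, Vec<u32>)` pair relies on, using the `aho-corasick` 0.7 API added in Cargo.toml:

use aho_corasick::{AhoCorasickBuilder, MatchKind};

fn main() {
    // Hypothetical added-token contents and their ids, mirroring (AhoCorasick, Vec<u32>).
    let contents = ["danc", "dancing"];
    let ids: Vec<u32> = vec![5000, 5001];

    let trie = AhoCorasickBuilder::new()
        // Leftmost-longest: the earliest-starting match wins and, on ties, the longest
        // pattern wins, independently of the order the patterns were inserted in.
        .match_kind(MatchKind::LeftmostLongest)
        .build(&contents);

    for mat in trie.find_iter("I like dancing") {
        // `pattern()` indexes back into `ids`; `start()`/`end()` are byte offsets.
        println!("id={} at {}..{}", ids[mat.pattern()], mat.start(), mat.end());
    }
    // Prints: id=5001 at 7..14 ("dancing" beats its prefix "danc").
}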


@@ -1670,6 +1670,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
name = "tokenizers"
version = "0.11.0"
dependencies = [
"aho-corasick",
"cached-path",
"clap",
"derive_builder",


@@ -1745,6 +1745,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
name = "tokenizers"
version = "0.11.0"
dependencies = [
"aho-corasick",
"cached-path",
"clap",
"derive_builder",


@@ -56,6 +56,7 @@ spm_precompiled = "0.1"
dirs = "3.0"
reqwest = { version = "0.11", optional = true }
cached-path = { version = "0.5", optional = true }
aho-corasick = "0.7"
[features]
default = ["progressbar", "http"]


@@ -0,0 +1,16 @@
use tokenizers::models::wordpiece::WordPiece;
use tokenizers::{AddedToken, Tokenizer};
fn main() {
let mut tokenizer = Tokenizer::new(WordPiece::default());
let tokens: Vec<_> = (0..4000)
.map(|i| AddedToken::from(format!("[SPECIAL_{}]", i), false))
.collect();
tokenizer.add_tokens(&tokens);
tokenizer.save("_tok.json", false).unwrap();
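// Time how long it takes to load the saved tokenizer (with its 4000 added tokens) back from disk.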
let start = std::time::Instant::now();
let _tok = Tokenizer::from_file("_tok.json").unwrap();
println!("Took {:?}", start.elapsed());
std::fs::remove_file("_tok.json").unwrap();
}


@@ -1,6 +1,7 @@
use super::{
normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
};
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use std::collections::{HashMap, HashSet};
@@ -61,52 +62,6 @@ impl AddedToken {
self.normalized = normalized;
self
}
/// Retrieve the pattern built for this token, according to all the specified parameters.
pub fn get_pattern<N: Normalizer>(&self, normalizer: Option<&N>) -> String {
let mut r = if self.single_word {
let first_b = self
.content
.chars()
.next()
.map(|c| {
if regex_syntax::is_word_character(c) {
r"\b"
} else {
""
}
})
.unwrap();
let last_b = self
.content
.chars()
.last()
.map(|c| {
if regex_syntax::is_word_character(c) {
r"\b"
} else {
""
}
})
.unwrap();
// Normalize the content
let mut content = NormalizedString::from(self.content.as_ref());
normalizer.map(|n| n.normalize(&mut content));
format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
} else {
regex::escape(&self.content)
};
if self.lstrip && self.rstrip {
r = format!(r"(\s)?{}(\s)?", r);
} else if self.lstrip {
r = format!(r"(\s)?{}", r);
} else if self.rstrip {
r = format!(r"{}(\s)?", r);
}
r
}
}
impl Default for AddedToken {
fn default() -> Self {
@@ -133,7 +88,7 @@ impl std::cmp::PartialEq for AddedToken {
}
impl std::cmp::Eq for AddedToken {}
type MatchingSet = (regex::RegexSet, Vec<u32>);
type MatchingSet = (AhoCorasick, Vec<u32>);
///
/// A vocabulary built on top of the Model
@@ -169,21 +124,27 @@ pub(super) struct AddedVocabulary {
special_tokens_set: HashSet<String>,
/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_re: MatchingSet,
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_re: MatchingSet,
split_normalized_trie: MatchingSet,
}
impl AddedVocabulary {
pub fn new() -> Self {
let trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest)
.build::<_, &&[u8]>(&[]);
let normalized_trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest)
.build::<_, &&[u8]>(&[]);
Self {
added_tokens_map: HashMap::new(),
added_tokens_map_r: HashMap::new(),
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
}
}
@@ -296,16 +257,26 @@ impl AddedVocabulary {
.partition(|(token, _)| token.normalized);
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
self.split_re = (
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
ids,
);
let trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest)
.build(tokens.iter().map(|token| &token.content));
self.split_trie = (trie, ids);
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
self.split_normalized_re = (
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
ids,
);
let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
let patterns: Vec<_> = ntokens
.iter()
.map(|token| {
let mut content = NormalizedString::from(token.content.as_ref());
if let Some(n) = normalizer {
n.normalize(&mut content).unwrap();
}
content
})
.collect();
let normalized_trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest)
.build(patterns.iter().map(|content| content.get()));
self.split_normalized_trie = (normalized_trie, nids);
}
/// Find any AddedToken in the given sentence, using the provided MatchingSet.
@@ -321,81 +292,49 @@ impl AddedVocabulary {
return vec![(None, (0, 0))];
}
let mut matches = split_re
.0
.matches(sentence)
.into_iter()
.flat_map(|idx| {
regex::Regex::new(&split_re.0.patterns()[idx])
.unwrap()
.find_iter(sentence)
.map(|m| (idx, (m.start(), m.end())))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// We sort all the matches by their start and then by their pattern id
matches.sort_by(
|(idxa, (sa, _)), (idxb, (sb, _))| {
if sa != sb {
sa.cmp(sb)
} else {
idxa.cmp(idxb)
}
},
);
// Select the matches (if some are overlapping) we want to keep
let mut i = 0;
let mut current_offset = 0;
let mut splits = Vec::with_capacity(matches.len());
while i < matches.len() {
let (idx, (start, end)) = matches[i];
// current match is before the current offset, let's skip it
if start < current_offset {
i += 1;
continue;
}
// Find out if we have overlapping neighbors. If so, we keep the one with the lowest
// idx, and apply it, then continue. All others will be skipped since `current_offset`
// will have been increased
if i + 1 < matches.len() {
if let Some((idx, (s, e))) = matches[i..]
.iter()
.take_while(|(_, (s, e))| *s < end && start < *e)
.min() // Order on idx first
.copied()
{
splits.push((idx, (s, e)));
current_offset = e;
i += 1;
continue;
}
}
// We didn't find overlapping neighbors, apply ourself
splits.push((idx, (start, end)));
current_offset = end;
i += 1;
}
// We also insert the splits that are in between the added tokens, to split the entire string
let mut start_offset = 0;
let mut splits = splits
.into_iter()
.flat_map(|(idx, (start, end))| {
let mut splits = vec![];
for mat in split_re.0.find_iter(sentence) {
let mut start = mat.start();
let mut stop = mat.end();
let aho_id = mat.pattern();
let id = split_re.1[aho_id];
let added_token = &self.added_tokens_map_r.get(&id).unwrap();
if added_token.single_word {
let start_space = start == 0 || sentence.as_bytes().get(start - 1) == Some(&b' ');
let stop_space =
stop == sentence.len() || sentence.as_bytes().get(stop) == Some(&b' ');
if !stop_space || !start_space {
// Not a single word, discard this match
continue;
}
}
if added_token.lstrip {
while start > 0 {
if let Some(b' ') = sentence.as_bytes().get(start - 1) {
start -= 1;
} else {
break;
}
}
}
if added_token.rstrip {
while stop < sentence.len() {
if let Some(b' ') = sentence.as_bytes().get(stop) {
stop += 1;
} else {
break;
}
}
}
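// Example (illustrative, with a hypothetical added token "[HYP]"): when lstrip is set,
// the input "x [HYP]" splits as "x" + " [HYP]", i.e. the preceding spaces are pulled
// into the added-token span; rstrip does the same for spaces on the right.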
if start_offset < start {
splits.push((None, (start_offset, start)));
}
splits.push((Some(split_re.1[idx] as u32), (start, end)));
start_offset = end;
splits
})
.collect::<Vec<_>>();
splits.push((Some(id), (start, stop)));
start_offset = stop;
}
let total_byte_len = sentence.len();
if start_offset != total_byte_len {
@@ -445,14 +384,14 @@ impl AddedVocabulary {
// 1. We extract all the non-normalized tokens from the non-normalized string
pretokenized
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_re)))
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
.expect("AddedVocabulary bad split");
// 2. Then extract the normalized tokens from the normalized pieces of the string
pretokenized
.split(|_, mut sequence| {
normalizer.map(|n| n.normalize(&mut sequence));
Ok(self.split_with_indices(sequence, &self.split_normalized_re))
Ok(self.split_with_indices(sequence, &self.split_normalized_trie))
})
.expect("AddedVocabulary bad split");
@@ -756,7 +695,7 @@ mod tests {
#[test]
fn empty_matches() {
let vocab = AddedVocabulary::new();
let matches = vocab.find_matches("", &vocab.split_re);
let matches = vocab.find_matches("", &vocab.split_trie);
assert_eq!(matches, vec![(None, (0, 0))]);
}
}


@@ -116,5 +116,12 @@ fn overlapping_tokens() {
let output = tokenizer.encode(input, false).unwrap();
assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġda", "nci", "ng"]);
// Breaking change but following `transformers` breaking change.
// This behavior is deemed not used in practice:
// https://github.com/huggingface/transformers/pull/13220
// Order does NOT matter. (We could make it work again but the trie
// would need to keep insertion order too)
//
// assert_eq!(output.get_tokens(), &["I", "Ġlike", "Ġda", "nci", "ng"]);
assert_eq!(output.get_tokens(), &["I", "Ġl", "ike", "Ġ", "danc", "ing"]);
}
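
To make the "Order does NOT matter" comment above concrete, here is a small standalone sketch (assumptions: the `aho-corasick` 0.7 crate, with tokens chosen purely for illustration rather than the exact tokens used by this test): with `MatchKind::LeftmostLongest` the reported spans are identical regardless of the order in which the overlapping patterns were registered, which is why insertion order can no longer influence the split.

use aho_corasick::{AhoCorasickBuilder, MatchKind};

// Collect the (start, end) spans reported for a given registration order.
fn spans(patterns: &[&str], haystack: &str) -> Vec<(usize, usize)> {
    AhoCorasickBuilder::new()
        .match_kind(MatchKind::LeftmostLongest)
        .build(patterns)
        .find_iter(haystack)
        .map(|m| (m.start(), m.end()))
        .collect()
}

fn main() {
    let haystack = "I like dancing";
    // The same overlapping tokens registered in two different orders.
    let a = spans(&["danc", "nci", "ing"], haystack);
    let b = spans(&["ing", "nci", "danc"], haystack);
    // Both orders yield [(7, 11), (11, 14)]: "danc" is kept because it starts first,
    // "nci" is dropped because it overlaps it, and "ing" is matched right after.
    assert_eq!(a, b);
    println!("{:?}", a);
}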