From 66be62b6e637586bb0facbfb36e8c6453ea8eb8f Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Fri, 12 Jun 2020 19:49:10 -0400
Subject: [PATCH] Rust - Extract AddedVocabulary management from Tokenizer

---
 tokenizers/src/models/bpe/model.rs        |   4 +-
 tokenizers/src/models/wordlevel/mod.rs    |   4 +-
 tokenizers/src/models/wordpiece/mod.rs    |   4 +-
 tokenizers/src/tokenizer/mod.rs           | 344 +++-------
 tokenizers/src/tokenizer/serialization.rs |  29 +-
 5 files changed, 46 insertions(+), 339 deletions(-)

diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index a26f2241..174e8d3b 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -433,8 +433,8 @@ impl Model for BPE {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs
index edaf32fa..f8c0dd93 100644
--- a/tokenizers/src/models/wordlevel/mod.rs
+++ b/tokenizers/src/models/wordlevel/mod.rs
@@ -169,8 +169,8 @@ impl Model for WordLevel {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn get_vocab(&self) -> &HashMap<String, u32> {
diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs
index 549ae092..339b5f87 100644
--- a/tokenizers/src/models/wordpiece/mod.rs
+++ b/tokenizers/src/models/wordpiece/mod.rs
@@ -283,8 +283,8 @@ impl Model for WordPiece {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index eff8531d..816ee656 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -15,19 +15,20 @@ pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
 pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
-use serde::{Deserialize, Serialize};
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
     fs::File,
     io::prelude::*,
     io::BufReader,
     path::{Path, PathBuf},
 };
 
+mod added_vocabulary;
 mod encoding;
 mod normalizer;
 mod serialization;
 
+pub use added_vocabulary::*;
 pub use encoding::*;
 pub use normalizer::*;
 
@@ -56,7 +57,7 @@ pub trait PreTokenizer: Send + Sync {
 pub trait Model: Send + Sync {
     fn tokenize(&self, tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
-    fn id_to_token(&self, id: u32) -> Option<String>;
+    fn id_to_token(&self, id: u32) -> Option<&str>;
     fn get_vocab(&self) -> &HashMap<String, u32>;
     fn get_vocab_size(&self) -> usize;
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
@@ -182,102 +183,6 @@ impl<I1: Into<String>, I2: Into<String>> From<(I1, I2)> for EncodeInput {
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AddedToken {
-    /// The content of the added token
-    pub content: String,
-    /// Whether this token must be a single word or can break words
-    pub single_word: bool,
-    /// Whether this token should strip whitespaces on its left
-    pub lstrip: bool,
-    /// Whether this token should strip whitespaces on its right
-    pub rstrip: bool,
-}
-impl AddedToken {
-    pub fn from(content: String) -> Self {
-        AddedToken {
-            content,
-            ..Default::default()
-        }
-    }
-    pub fn single_word(mut self, single_word: bool) -> Self {
-        self.single_word = single_word;
-        self
-    }
-    pub fn lstrip(mut self, lstrip: bool) -> Self {
-        self.lstrip = lstrip;
-        self
-    }
-    pub fn rstrip(mut self, rstrip: bool) -> Self {
-        self.rstrip = rstrip;
-        self
-    }
-    pub fn get_pattern(&self) -> String {
-        let mut r = if self.single_word {
-            let first_b = self
-                .content
-                .chars()
-                .next()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            let last_b = self
-                .content
-                .chars()
-                .last()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            format!(r"{}{}{}", first_b, regex::escape(&self.content), last_b)
-        } else {
-            regex::escape(&self.content)
-        };
-
-        if self.lstrip && self.rstrip {
-            r = format!(r"(\s)?{}(\s)?", r);
-        } else if self.lstrip {
-            r = format!(r"(\s)?{}", r);
-        } else if self.rstrip {
-            r = format!(r"{}(\s)?", r);
-        }
-
-        r
-    }
-}
-impl Default for AddedToken {
-    fn default() -> Self {
-        AddedToken {
-            content: String::new(),
-            single_word: false,
-            lstrip: false,
-            rstrip: false,
-        }
-    }
-}
-// We only want to hash on the content. AddedToken cannot be added multiple times with different
-// options
-impl std::hash::Hash for AddedToken {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.content.hash(state);
-    }
-}
-impl std::cmp::PartialEq for AddedToken {
-    fn eq(&self, other: &Self) -> bool {
-        self.content == other.content
-    }
-}
-impl std::cmp::Eq for AddedToken {}
-
 /// A `Tokenizer` is capable of encoding/decoding any text.
 pub struct Tokenizer {
     // Tokenizer parts
@@ -288,21 +193,7 @@ pub struct Tokenizer {
     decoder: Option<Box<dyn Decoder>>,
 
     // Added Vocabulary capabilities
-    /// Contains the mapping from String to ID as the user intended it. This map
-    /// contains both special tokens and classic added tokens.
-    added_tokens_map: HashMap<String, u32>,
-    /// Contains the mapping from ID to AddedToken for all the added tokens, both special
-    /// and classic.
-    added_tokens_map_r: HashMap<u32, AddedToken>,
-    /// Contains only the classic AddedToken, in the specific order the user gave them.
-    added_tokens: Vec<AddedToken>,
-    /// Contains only the special AddedToken, in the specific order the user gave them.
-    special_tokens: Vec<AddedToken>,
-    /// A Set, containing all the special tokens for easy access while decoding. This lets
-    /// us remove them easily with an O(1) complexity.
-    special_tokens_set: HashSet<String>,
-    /// A RegexSet containing all the patterns used to split on AddedTokens
-    split_re: regex::RegexSet,
+    added_vocabulary: AddedVocabulary,
 
     // General processing parameters
     truncation: Option<TruncationParams>,
@@ -327,12 +218,7 @@ impl Tokenizer {
             post_processor: None,
             decoder: None,
 
-            added_tokens_map: HashMap::new(),
-            added_tokens_map_r: HashMap::new(),
-            added_tokens: vec![],
-            special_tokens: vec![],
-            special_tokens_set: HashSet::new(),
-            split_re: regex::RegexSet::new::<_, &&str>(&[]).unwrap(),
+            added_vocabulary: AddedVocabulary::new(),
 
             truncation: None,
             padding: None,
@@ -461,10 +347,13 @@ impl Tokenizer {
     pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
         let mut final_vocab = self.model.get_vocab().clone();
 
-        if with_added_tokens && !self.added_tokens_map.is_empty() {
-            final_vocab.reserve(self.added_tokens_map.len());
-            for (token, id) in &self.added_tokens_map {
-                final_vocab.insert(token.clone(), *id);
+        if with_added_tokens {
+            let added_vocab = self.added_vocabulary.get_vocab();
+            if !added_vocab.is_empty() {
+                final_vocab.reserve(added_vocab.len());
+                for (token, id) in added_vocab {
+                    final_vocab.insert(token.clone(), *id);
+                }
             }
         }
 
@@ -475,7 +364,7 @@ impl Tokenizer {
     pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
         self.model.get_vocab_size()
             + if with_added_tokens {
-                self.added_tokens_map.len()
+                self.added_vocabulary.len()
             } else {
                 0
             }
@@ -483,26 +372,24 @@ impl Tokenizer {
 
     /// Converts a token into the corresponding id.
     pub fn token_to_id(&self, token: &str) -> Option<u32> {
-        if let Some(id) = self.added_tokens_map.get(token) {
-            Some(*id)
-        } else {
-            self.model.token_to_id(token)
-        }
+        self.added_vocabulary
+            .token_to_id(token)
+            .copied()
+            .or_else(|| self.model.token_to_id(token))
     }
 
     /// Converts an id to the corresponding token.
-    pub fn id_to_token(&self, id: u32) -> Option<String> {
-        if let Some(token) = self.added_tokens_map_r.get(&id) {
-            Some(token.content.clone())
-        } else {
-            self.model.id_to_token(id)
-        }
+    pub fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.added_vocabulary
+            .id_to_token(id)
+            .or_else(|| self.model.id_to_token(id))
     }
 
     /// Normalize the given sentence and return the corresponding normalized string
     pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
         let mut normalized = self
-            .split_on_added_tokens(sentence)
+            .added_vocabulary
+            .extract(sentence)
             .into_iter()
             .map(|(sentence, id)| -> Result<NormalizedString> {
                 if id.is_some() {
@@ -534,7 +421,7 @@ impl Tokenizer {
         let mut sequence_encodings = vec![];
 
         for subseq in sequence {
-            let results = self.split_on_added_tokens(&subseq).into_iter().map(
+            let results = self.added_vocabulary.extract(&subseq).into_iter().map(
                 |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                     if let Some(id) = id {
                         Ok((
@@ -666,15 +553,17 @@ impl Tokenizer {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                let token = if let Some(token) = self.added_tokens_map_r.get(&id) {
-                    Some(token.content.to_owned())
+                let token = if let Some(token) = self.added_vocabulary.id_to_token(id) {
+                    Some(token)
                 } else {
                     self.model.id_to_token(id)
                 };
 
-                token.filter(|token| {
-                    !skip_special_tokens || !self.special_tokens_set.contains(token)
-                })
+                token
+                    .filter(|token| {
+                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
+                    })
+                    .map(|t| t.to_owned())
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
@@ -863,174 +752,13 @@ impl Tokenizer {
     /// Register the given tokens as special tokens. This is especially useful for removing
     /// these special tokens while decoding
     pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
-                self.special_tokens.push(token.to_owned());
-                self.special_tokens_set.insert(token.content.clone());
-            }
-        }
-
-        let added = self.add_tokens(&tokens);
-
-        self.refresh_added_tokens();
-
-        added
+        self.added_vocabulary
+            .add_special_tokens(tokens, self.model.as_ref())
     }
 
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        let mut ignored = 0;
-        for token in tokens {
-            if token.content.is_empty() {
-                ignored += 1;
-                continue;
-            }
-
-            let id = if let Some(id) = self.token_to_id(&token.content) {
-                ignored += 1;
-                id
-            } else {
-                let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
-
-                if !self.special_tokens_set.contains(&token.content) {
-                    self.added_tokens.push(token.clone());
-                }
-
-                new_id
-            };
-
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
-        }
-
-        self.refresh_added_tokens();
-
-        // Return the number of added tokens
-        tokens.len() - ignored
-    }
-
-    fn refresh_added_tokens(&mut self) {
-        self.split_re = regex::RegexSet::new(
-            self.special_tokens
-                .iter()
-                .chain(self.added_tokens.iter())
-                .map(|token| token.get_pattern()),
-        )
-        .unwrap();
-    }
-
-    /// Split the given sentence on multiple parts, finding the added tokens and their id in
-    /// the process
-    fn split_on_added_tokens(&self, sentence: &str) -> Vec<(NormalizedString, Option<u32>)> {
-        let mut matches = self
-            .split_re
-            .matches(sentence)
-            .into_iter()
-            .flat_map(|idx| {
-                regex::Regex::new(&self.split_re.patterns()[idx])
-                    .unwrap()
-                    .find_iter(&sentence)
-                    .map(|m| (idx, (m.start(), m.end())))
-                    .collect::<Vec<_>>()
-            })
-            .collect::<Vec<_>>();
-
-        // We sort all the matches by their start and then by their pattern id
-        matches.sort_by(
-            |(idxa, (sa, _)), (idxb, (sb, _))| {
-                if sa != sb {
-                    sa.cmp(sb)
-                } else {
-                    idxa.cmp(idxb)
-                }
-            },
-        );
-
-        // Select the matches (if some are overlapping) we want to keep
-        let mut i = 0;
-        let mut current_offset = 0;
-        let mut splits = Vec::with_capacity(matches.len());
-        while i < matches.len() {
-            let (idx, (start, end)) = matches[i];
-
-            // current match is before the current offset, let's skip it
-            if start < current_offset {
-                i += 1;
-                continue;
-            }
-
-            // Find out if we have overlapping neighbors. If so, we keep the one with the lowest
-            // idx, and apply it, then continue. All others will be skipped since `current_offset`
-            // will have been increased
-            if i + 1 < matches.len() {
-                if let Some((idx, (s, e))) = matches[i..]
-                    .iter()
-                    .take_while(|(_, (s, e))| *s < end && start < *e)
-                    .min() // Order on idx first
-                    .copied()
-                {
-                    splits.push((idx, (s, e)));
-                    current_offset = e;
-                    i += 1;
-                    continue;
-                }
-            }
-
-            // We didn't find overlapping neighbors, apply ourselves
-            splits.push((idx, (start, end)));
-            current_offset = end;
-            i += 1;
-        }
-
-        // We also insert the splits that are in between the added tokens, to split the entire string
-        let mut start_offset = 0;
-        let mut splits = splits
-            .into_iter()
-            .flat_map(|(idx, (start, end))| {
-                let mut splits = vec![];
-                if start_offset < start {
-                    splits.push((None, (start_offset, start)));
-                }
-                splits.push((Some(idx), (start, end)));
-                start_offset = end;
-
-                splits
-            })
-            .collect::<Vec<_>>();
-        if let Some((_, (_, end))) = splits.iter().last().copied() {
-            if end < sentence.len() {
-                splits.push((None, (end, sentence.len())));
-            }
-        }
-
-        if splits.is_empty() {
-            vec![(NormalizedString::from(sentence), None)]
-        } else {
-            splits
-                .into_iter()
-                .map(|(idx, (start, end))| unsafe {
-                    let s = sentence.get_unchecked(start..end).to_owned();
-                    let normalized = NormalizedString::from(&s);
-
-                    // Find out the associated AddedToken, and its id
-                    let id = if let Some(idx) = idx {
-                        let added = if idx >= self.special_tokens.len() {
-                            &self.added_tokens[idx - self.special_tokens.len()]
-                        } else {
-                            &self.special_tokens[idx]
-                        };
-
-                        self.token_to_id(&added.content)
-                    } else {
-                        None
-                    };
-
-                    (normalized, id)
-                })
-                .collect()
-        }
+        self.added_vocabulary
+            .add_tokens(tokens, self.model.as_ref())
     }
 }
diff --git a/tokenizers/src/tokenizer/serialization.rs b/tokenizers/src/tokenizer/serialization.rs
index 9898ba36..551977d0 100644
--- a/tokenizers/src/tokenizer/serialization.rs
+++ b/tokenizers/src/tokenizer/serialization.rs
@@ -1,4 +1,4 @@
-use super::{AddedToken, Tokenizer};
+use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
 use crate::models::bpe::BPE;
 use serde::{
     self,
@@ -9,18 +9,6 @@ use serde::{
 
 static SERIALIZATION_VERSION: &str = "1.0";
 
-#[derive(Debug, Serialize, Deserialize)]
-struct AddedTokenWithId {
-    /// The id assigned to this token
-    id: u32,
-    /// Whether this is a special token
-    special: bool,
-
-    #[serde(flatten)]
-    /// The target AddedToken
-    token: AddedToken,
-}
-
 impl Serialize for Tokenizer {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -36,18 +24,7 @@ impl Serialize for Tokenizer {
         tokenizer.serialize_field("padding", &self.padding)?;
 
         // Added tokens
-        let mut added_tokens = self
-            .added_tokens_map_r
-            .iter()
-            .map(|(id, token)| AddedTokenWithId {
-                id: *id,
-                special: self.special_tokens_set.contains(&token.content),
-                token: token.clone(),
-            })
-            .collect::<Vec<_>>();
-        // We need to have these added tokens ordered by ascending ID
-        added_tokens.sort_unstable_by_key(|o| o.id);
-        tokenizer.serialize_field("added_tokens", &added_tokens)?;
+        tokenizer.serialize_field("added_tokens", &self.added_vocabulary)?;
 
         // Then add our parts
         tokenizer.serialize_field("normalizer", &self.normalizer)?;
@@ -141,6 +118,8 @@ impl<'de> Visitor<'de> for TokenizerVisitor {
             };
         }
 
+        // We take care of deserializing the added_tokens (instead of the `AddedVocabulary` directly)
+        // because it lets us check that the associated IDs are still good, and warn the user otherwise
        for token in tokens {
            let tk = token.token.content.clone();
            if token.special {
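
Note: the new `tokenizers/src/tokenizer/added_vocabulary.rs` file itself is not included in the excerpt above. Pieced together from its call sites in the hunks, the module's public surface looks roughly like the sketch below. This is a hypothetical reconstruction, not code from the commit: the stand-in `Model` trait, the field names, and the method bodies are inferred, and the real module additionally carries the regex-based `extract` splitting and the `AddedTokenWithId` serialization support that the diff references.

    use std::collections::{HashMap, HashSet};

    // Stand-in for the `AddedToken` moved out of mod.rs (the single_word/lstrip/rstrip
    // options are elided here).
    #[derive(Clone, Default)]
    pub struct AddedToken {
        pub content: String,
    }

    // Minimal stand-in for the two `Model` methods the vocabulary needs.
    pub trait Model {
        fn token_to_id(&self, token: &str) -> Option<u32>;
        fn get_vocab_size(&self) -> usize;
    }

    #[derive(Default)]
    pub struct AddedVocabulary {
        // token -> id, for both special and classic added tokens
        added_tokens_map: HashMap<String, u32>,
        // id -> token, the reverse mapping used while decoding
        added_tokens_map_r: HashMap<u32, AddedToken>,
        // special tokens, kept in a set for O(1) lookups during decoding
        special_tokens_set: HashSet<String>,
    }

    impl AddedVocabulary {
        pub fn new() -> Self {
            Self::default()
        }

        pub fn len(&self) -> usize {
            self.added_tokens_map.len()
        }

        pub fn get_vocab(&self) -> &HashMap<String, u32> {
            &self.added_tokens_map
        }

        pub fn token_to_id(&self, token: &str) -> Option<&u32> {
            self.added_tokens_map.get(token)
        }

        pub fn id_to_token(&self, id: u32) -> Option<&str> {
            self.added_tokens_map_r.get(&id).map(|t| t.content.as_str())
        }

        pub fn is_special_token(&self, token: &str) -> bool {
            self.special_tokens_set.contains(token)
        }

        // Ids for added tokens start right after the model vocabulary, exactly like the
        // removed `Tokenizer::add_tokens` did. Returns the number of newly added tokens.
        pub fn add_tokens(&mut self, tokens: &[AddedToken], model: &dyn Model) -> usize {
            let mut added = 0;
            for token in tokens {
                if token.content.is_empty() || model.token_to_id(&token.content).is_some() {
                    continue; // ignore empty tokens and tokens the model already knows
                }
                let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
                let id = *self
                    .added_tokens_map
                    .entry(token.content.clone())
                    .or_insert_with(|| {
                        added += 1;
                        new_id
                    });
                // keep the reverse mapping in sync
                self.added_tokens_map_r.insert(id, token.clone());
            }
            added
        }

        // Same as `add_tokens`, but also remembers the tokens as special ones, so that
        // decoding with `skip_special_tokens` can filter them out.
        pub fn add_special_tokens(&mut self, tokens: &[AddedToken], model: &dyn Model) -> usize {
            for token in tokens {
                self.special_tokens_set.insert(token.content.clone());
            }
            self.add_tokens(tokens, model)
        }
    }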
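Seen from the outside, the `Tokenizer` API is meant to behave exactly as before the refactor; only `id_to_token` changes shape, returning a borrowed `Option<&str>` instead of an owned `Option<String>`. A hypothetical caller against the 0.8-era crate could look like this (the `BPE::from_files(...).build()` constructor and the file names are assumptions, not part of this patch):

    use tokenizers::models::bpe::BPE;
    use tokenizers::tokenizer::{AddedToken, Tokenizer};

    fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Assumed constructor for a BPE model from local vocab/merges files.
        let bpe = BPE::from_files("vocab.json", "merges.txt").build()?;
        let mut tokenizer = Tokenizer::new(Box::new(bpe));

        // Both calls now simply forward to the internal `AddedVocabulary`.
        tokenizer.add_tokens(&[AddedToken::from("[MASK]".into())]);
        tokenizer.add_special_tokens(&[AddedToken::from("[PAD]".into())]);

        // Lookups consult the added vocabulary first, then fall back to the model,
        // and `id_to_token` now hands back a borrowed `Option<&str>`.
        let id = tokenizer.token_to_id("[MASK]");
        let token: Option<&str> = id.and_then(|id| tokenizer.id_to_token(id));
        println!("{:?} -> {:?}", id, token);

        Ok(())
    }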