Rust - Extract AddedVocabulary management from Tokenizer

Anthony MOI
2020-06-12 19:49:10 -04:00
parent 6091c9b229
commit 66be62b6e6
5 changed files with 46 additions and 339 deletions
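
This commit moves all added-token bookkeeping out of `Tokenizer` and into a new `added_vocabulary` module. A minimal sketch of the extracted interface, inferred only from the call sites in the hunks below; the field layout and exact signatures are assumptions, not the real module:

    // Sketch only: AddedToken, Model and NormalizedString are the crate's own types,
    // and the fields simply mirror the ones removed from Tokenizer further down.
    use std::collections::{HashMap, HashSet};

    pub struct AddedVocabulary {
        /// token content -> id (was Tokenizer::added_tokens_map)
        added_tokens_map: HashMap<String, u32>,
        /// id -> AddedToken (was Tokenizer::added_tokens_map_r)
        added_tokens_map_r: HashMap<u32, AddedToken>,
        /// special tokens, for O(1) lookups while decoding
        special_tokens_set: HashSet<String>,
        /// patterns used to split input on added tokens
        split_re: regex::RegexSet,
    }

    impl AddedVocabulary {
        pub fn new() -> Self { unimplemented!() }
        pub fn len(&self) -> usize { unimplemented!() }
        pub fn get_vocab(&self) -> &HashMap<String, u32> { unimplemented!() }
        pub fn token_to_id(&self, token: &str) -> Option<&u32> { unimplemented!() }
        pub fn id_to_token(&self, id: u32) -> Option<&str> { unimplemented!() }
        pub fn is_special_token(&self, token: &str) -> bool { unimplemented!() }
        /// Splits `sentence` around the added tokens, attaching their id when found
        pub fn extract(&self, sentence: &str) -> Vec<(NormalizedString, Option<u32>)> { unimplemented!() }
        pub fn add_tokens(&mut self, tokens: &[AddedToken], model: &dyn Model) -> usize { unimplemented!() }
        pub fn add_special_tokens(&mut self, tokens: &[AddedToken], model: &dyn Model) -> usize { unimplemented!() }
    }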


@@ -433,8 +433,8 @@ impl Model for BPE {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {


@@ -169,8 +169,8 @@ impl Model for WordLevel {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn get_vocab(&self) -> &HashMap<String, u32> {


@@ -283,8 +283,8 @@ impl Model for WordPiece {
         self.vocab.get(token).copied()
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
-        self.vocab_r.get(&id).cloned()
+    fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.vocab_r.get(&id).map(String::as_ref)
     }
 
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {


@@ -15,19 +15,20 @@ pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams,
 pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use serde::{Deserialize, Serialize};
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
     fs::File,
     io::prelude::*,
     io::BufReader,
     path::{Path, PathBuf},
 };
 
+mod added_vocabulary;
 mod encoding;
 mod normalizer;
 mod serialization;
 
+pub use added_vocabulary::*;
 pub use encoding::*;
 pub use normalizer::*;
@@ -56,7 +57,7 @@ pub trait PreTokenizer: Send + Sync {
 pub trait Model: Send + Sync {
     fn tokenize(&self, tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
-    fn id_to_token(&self, id: u32) -> Option<String>;
+    fn id_to_token(&self, id: u32) -> Option<&str>;
     fn get_vocab(&self) -> &HashMap<String, u32>;
     fn get_vocab_size(&self) -> usize;
     fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
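
`Model::id_to_token` now hands back a `&str` borrowed from the model's reverse vocabulary instead of allocating a fresh `String` on every lookup; callers that need ownership (like the decode path further down) opt in with `.to_owned()`. A standalone illustration of the `.map(String::as_ref)` pattern used by the three model implementations above, not the crate's actual trait:

    use std::collections::HashMap;

    // Toy stand-in for a reverse vocabulary lookup: the &str is borrowed from
    // the map, so no String is cloned unless the caller asks for it.
    fn id_to_token(vocab_r: &HashMap<u32, String>, id: u32) -> Option<&str> {
        vocab_r.get(&id).map(String::as_ref)
    }

    fn main() {
        let mut vocab_r = HashMap::new();
        vocab_r.insert(0u32, "hello".to_string());

        assert_eq!(id_to_token(&vocab_r, 0), Some("hello"));
        assert_eq!(id_to_token(&vocab_r, 1), None);

        // Ownership is now an explicit choice, as in the decode hunk below:
        let owned: Option<String> = id_to_token(&vocab_r, 0).map(|t| t.to_owned());
        assert_eq!(owned.as_deref(), Some("hello"));
    }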
@@ -182,102 +183,6 @@ impl<I1: Into<InputSequence>, I2: Into<InputSequence>> From<(I1, I2)> for Encode
     }
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct AddedToken {
-    /// The content of the added token
-    pub content: String,
-    /// Whether this token must be a single word or can break words
-    pub single_word: bool,
-    /// Whether this token should strip whitespaces on its left
-    pub lstrip: bool,
-    /// Whether this token should strip whitespaces on its right
-    pub rstrip: bool,
-}
-
-impl AddedToken {
-    pub fn from(content: String) -> Self {
-        AddedToken {
-            content,
-            ..Default::default()
-        }
-    }
-    pub fn single_word(mut self, single_word: bool) -> Self {
-        self.single_word = single_word;
-        self
-    }
-    pub fn lstrip(mut self, lstrip: bool) -> Self {
-        self.lstrip = lstrip;
-        self
-    }
-    pub fn rstrip(mut self, rstrip: bool) -> Self {
-        self.rstrip = rstrip;
-        self
-    }
-    pub fn get_pattern(&self) -> String {
-        let mut r = if self.single_word {
-            let first_b = self
-                .content
-                .chars()
-                .next()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            let last_b = self
-                .content
-                .chars()
-                .last()
-                .map(|c| {
-                    if regex_syntax::is_word_character(c) {
-                        r"\b"
-                    } else {
-                        ""
-                    }
-                })
-                .unwrap();
-            format!(r"{}{}{}", first_b, regex::escape(&self.content), last_b)
-        } else {
-            regex::escape(&self.content)
-        };
-        if self.lstrip && self.rstrip {
-            r = format!(r"(\s)?{}(\s)?", r);
-        } else if self.lstrip {
-            r = format!(r"(\s)?{}", r);
-        } else if self.rstrip {
-            r = format!(r"{}(\s)?", r);
-        }
-        r
-    }
-}
-
-impl Default for AddedToken {
-    fn default() -> Self {
-        AddedToken {
-            content: String::new(),
-            single_word: false,
-            lstrip: false,
-            rstrip: false,
-        }
-    }
-}
-
-// We only want to hash on the content. AddedToken cannot be added multiple times with different
-// options
-impl std::hash::Hash for AddedToken {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.content.hash(state);
-    }
-}
-impl std::cmp::PartialEq for AddedToken {
-    fn eq(&self, other: &Self) -> bool {
-        self.content == other.content
-    }
-}
-impl std::cmp::Eq for AddedToken {}
 
 /// A `Tokenizer` is capable of encoding/decoding any text.
 pub struct Tokenizer {
     // Tokenizer parts
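
Before the move, `AddedToken::get_pattern` above built one regex per added token: `single_word` adds `\b` anchors only when the token's boundary characters are word characters, and `lstrip`/`rstrip` wrap the escaped content in optional whitespace. A small standalone check of the patterns that code produces (the escaped strings are written out by hand; only the `regex` crate, already a dependency above, is needed):

    fn main() {
        // AddedToken::from("[MASK]".to_string()).lstrip(true).rstrip(true)
        // escapes the content and wraps it in optional whitespace:
        let re = regex::Regex::new(r"(\s)?\[MASK\](\s)?").unwrap();
        assert!(re.is_match("hello [MASK] world"));

        // single_word(true) anchors "mask" as r"\bmask\b", but leaves "[MASK]"
        // unanchored because '[' and ']' are not word characters:
        let word_re = regex::Regex::new(r"\bmask\b").unwrap();
        assert!(word_re.is_match("a mask here"));
        assert!(!word_re.is_match("unmasked"));
    }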
@@ -288,21 +193,7 @@ pub struct Tokenizer {
     decoder: Option<Box<dyn Decoder>>,
 
     // Added Vocabulary capabilities
-    /// Contains the mapping from String to ID as the user intended it. This map
-    /// contains both special tokens and classic added tokens.
-    added_tokens_map: HashMap<String, u32>,
-    /// Contains the mapping from ID to AddedToken for all the added tokens, both special
-    /// and classic.
-    added_tokens_map_r: HashMap<u32, AddedToken>,
-    /// Contains only the classic AddedToken, in the specific order the user gave them.
-    added_tokens: Vec<AddedToken>,
-    /// Contains only the special AddedToken, in the specific order the user gave them.
-    special_tokens: Vec<AddedToken>,
-    /// A Set, containing all the special token for easy access while decoding. This let's
-    /// use remove them easily with an O(1) complexity.
-    special_tokens_set: HashSet<String>,
-    /// A RegexSet containing all the patterns used to split on AddedTokens
-    split_re: regex::RegexSet,
+    added_vocabulary: AddedVocabulary,
 
     // General processing parameters
     truncation: Option<TruncationParams>,
@@ -327,12 +218,7 @@ impl Tokenizer {
             post_processor: None,
             decoder: None,
 
-            added_tokens_map: HashMap::new(),
-            added_tokens_map_r: HashMap::new(),
-            added_tokens: vec![],
-            special_tokens: vec![],
-            special_tokens_set: HashSet::new(),
-            split_re: regex::RegexSet::new::<_, &&str>(&[]).unwrap(),
+            added_vocabulary: AddedVocabulary::new(),
 
             truncation: None,
             padding: None,
@@ -461,12 +347,15 @@ impl Tokenizer {
     pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
         let mut final_vocab = self.model.get_vocab().clone();
 
-        if with_added_tokens && !self.added_tokens_map.is_empty() {
-            final_vocab.reserve(self.added_tokens_map.len());
-            for (token, id) in &self.added_tokens_map {
-                final_vocab.insert(token.clone(), *id);
-            }
-        }
+        if with_added_tokens {
+            let added_vocab = self.added_vocabulary.get_vocab();
+            if !added_vocab.is_empty() {
+                final_vocab.reserve(added_vocab.len());
+                for (token, id) in added_vocab {
+                    final_vocab.insert(token.clone(), *id);
+                }
+            }
+        }
 
         final_vocab
     }
@@ -475,7 +364,7 @@ impl Tokenizer {
     pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
         self.model.get_vocab_size()
             + if with_added_tokens {
-                self.added_tokens_map.len()
+                self.added_vocabulary.len()
             } else {
                 0
             }
@@ -483,26 +372,24 @@ impl Tokenizer {
     /// Converts a token in the corresponding id.
     pub fn token_to_id(&self, token: &str) -> Option<u32> {
-        if let Some(id) = self.added_tokens_map.get(token) {
-            Some(*id)
-        } else {
-            self.model.token_to_id(token)
-        }
+        self.added_vocabulary
+            .token_to_id(token)
+            .copied()
+            .or_else(|| self.model.token_to_id(token))
     }
 
     /// Converts an id to the corresponding token.
-    pub fn id_to_token(&self, id: u32) -> Option<String> {
-        if let Some(token) = self.added_tokens_map_r.get(&id) {
-            Some(token.content.clone())
-        } else {
-            self.model.id_to_token(id)
-        }
+    pub fn id_to_token(&self, id: u32) -> Option<&str> {
+        self.added_vocabulary
+            .id_to_token(id)
+            .or_else(|| self.model.id_to_token(id))
     }
 
     /// Normalize the given sentence and return the corresponding normalized string
     pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
         let mut normalized = self
-            .split_on_added_tokens(sentence)
+            .added_vocabulary
+            .extract(sentence)
             .into_iter()
             .map(|(sentence, id)| -> Result<NormalizedString> {
                 if id.is_some() {
@@ -534,7 +421,7 @@ impl Tokenizer {
         let mut sequence_encodings = vec![];
         for subseq in sequence {
-            let results = self.split_on_added_tokens(&subseq).into_iter().map(
+            let results = self.added_vocabulary.extract(&subseq).into_iter().map(
                 |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                     if let Some(id) = id {
                         Ok((
@@ -666,15 +553,17 @@ impl Tokenizer {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                let token = if let Some(token) = self.added_tokens_map_r.get(&id) {
-                    Some(token.content.to_owned())
+                let token = if let Some(token) = self.added_vocabulary.id_to_token(id) {
+                    Some(token)
                 } else {
                     self.model.id_to_token(id)
                 };
-                token.filter(|token| {
-                    !skip_special_tokens || !self.special_tokens_set.contains(token)
-                })
+                token
+                    .filter(|token| {
+                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
+                    })
+                    .map(|t| t.to_owned())
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
@@ -863,174 +752,13 @@ impl Tokenizer {
     /// Register the given tokens as special tokens. This is especially useful for removing
     /// these special tokens while decoding
     pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
-                self.special_tokens.push(token.to_owned());
-                self.special_tokens_set.insert(token.content.clone());
-            }
-        }
-
-        let added = self.add_tokens(&tokens);
-        self.refresh_added_tokens();
-
-        added
+        self.added_vocabulary
+            .add_special_tokens(tokens, self.model.as_ref())
     }
 
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        let mut ignored = 0;
-        for token in tokens {
-            if token.content.is_empty() {
-                ignored += 1;
-                continue;
-            }
-            let id = if let Some(id) = self.token_to_id(&token.content) {
-                ignored += 1;
-                id
-            } else {
-                let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
-                if !self.special_tokens_set.contains(&token.content) {
-                    self.added_tokens.push(token.clone());
-                }
-                new_id
-            };
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
-        }
-
-        self.refresh_added_tokens();
-
-        // Return the number of added tokens
-        tokens.len() - ignored
-    }
-
-    fn refresh_added_tokens(&mut self) {
-        self.split_re = regex::RegexSet::new(
-            self.special_tokens
-                .iter()
-                .chain(self.added_tokens.iter())
-                .map(|token| token.get_pattern()),
-        )
-        .unwrap();
-    }
-
-    /// Split the given sentence on multiple parts, finding the added tokens and their id in
-    /// the process
-    fn split_on_added_tokens(&self, sentence: &str) -> Vec<(NormalizedString, Option<u32>)> {
-        let mut matches = self
-            .split_re
-            .matches(sentence)
-            .into_iter()
-            .flat_map(|idx| {
-                regex::Regex::new(&self.split_re.patterns()[idx])
-                    .unwrap()
-                    .find_iter(&sentence)
-                    .map(|m| (idx, (m.start(), m.end())))
-                    .collect::<Vec<_>>()
-            })
-            .collect::<Vec<_>>();
-
-        // We sort all the matches by their start and then by their pattern id
-        matches.sort_by(
-            |(idxa, (sa, _)), (idxb, (sb, _))| {
-                if sa != sb {
-                    sa.cmp(sb)
-                } else {
-                    idxa.cmp(idxb)
-                }
-            },
-        );
-
-        // Select the matches (if some are overlapping) we want to keep
-        let mut i = 0;
-        let mut current_offset = 0;
-        let mut splits = Vec::with_capacity(matches.len());
-        while i < matches.len() {
-            let (idx, (start, end)) = matches[i];
-
-            // current match is before the currentt offset, let's skip it
-            if start < current_offset {
-                i += 1;
-                continue;
-            }
-
-            // Find out if we have overlapping neighbors. If so, we keep the one with the lowest
-            // idx, and apply it, then continue. All others will be skipped since `current_offset`
-            // will have been increased
-            if i + 1 < matches.len() {
-                if let Some((idx, (s, e))) = matches[i..]
-                    .iter()
-                    .take_while(|(_, (s, e))| *s < end && start < *e)
-                    .min() // Order on idx first
-                    .copied()
-                {
-                    splits.push((idx, (s, e)));
-                    current_offset = e;
-                    i += 1;
-                    continue;
-                }
-            }
-
-            // We didn't find overlapping neighbors, apply ourself
-            splits.push((idx, (start, end)));
-            current_offset = end;
-            i += 1;
-        }
-
-        // We also insert the splits that are inbetween the added tokens, to split the entire string
-        let mut start_offset = 0;
-        let mut splits = splits
-            .into_iter()
-            .flat_map(|(idx, (start, end))| {
-                let mut splits = vec![];
-                if start_offset < start {
-                    splits.push((None, (start_offset, start)));
-                }
-                splits.push((Some(idx), (start, end)));
-                start_offset = end;
-
-                splits
-            })
-            .collect::<Vec<_>>();
-        if let Some((_, (_, end))) = splits.iter().last().copied() {
-            if end < sentence.len() {
-                splits.push((None, (end, sentence.len())));
-            }
-        }
-
-        if splits.is_empty() {
-            vec![(NormalizedString::from(sentence), None)]
-        } else {
-            splits
-                .into_iter()
-                .map(|(idx, (start, end))| unsafe {
-                    let s = sentence.get_unchecked(start..end).to_owned();
-                    let normalized = NormalizedString::from(&s);
-
-                    // Find out the associated AddedToken, and its id
-                    let id = if let Some(idx) = idx {
-                        let added = if idx >= self.special_tokens.len() {
-                            &self.added_tokens[idx - self.special_tokens.len()]
-                        } else {
-                            &self.special_tokens[idx]
-                        };
-                        self.token_to_id(&added.content)
-                    } else {
-                        None
-                    };
-
-                    (normalized, id)
-                })
-                .collect()
-        }
+        self.added_vocabulary
+            .add_tokens(tokens, self.model.as_ref())
     }
 }
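
The removed `split_on_added_tokens` is now the job of `AddedVocabulary::extract`. Its core step is overlap resolution: matches are sorted by start offset, then by pattern index, and among overlapping matches the lowest pattern index wins (special tokens are registered before classic added tokens, so they take priority). A standalone sketch of just that selection pass, using the same `(pattern_idx, (start, end))` tuples as the removed code:

    // Keep non-overlapping matches, preferring the lowest pattern index.
    fn select_matches(mut matches: Vec<(usize, (usize, usize))>) -> Vec<(usize, (usize, usize))> {
        // Sort by start offset, then by pattern index
        matches.sort_by(|(ia, (sa, _)), (ib, (sb, _))| sa.cmp(sb).then(ia.cmp(ib)));

        let mut selected = Vec::with_capacity(matches.len());
        let mut current_offset = 0;
        for i in 0..matches.len() {
            let (_, (start, end)) = matches[i];
            // Already covered by a previously selected match
            if start < current_offset {
                continue;
            }
            // Among the neighbours overlapping this match, keep the lowest idx
            if let Some(best) = matches[i..]
                .iter()
                .take_while(|(_, (s, e))| *s < end && start < *e)
                .min()
                .copied()
            {
                current_offset = (best.1).1;
                selected.push(best);
            }
        }
        selected
    }

    fn main() {
        // Patterns 0 and 1 overlap on the first span: the lower index (0) wins.
        let matches = vec![(1, (0, 5)), (0, (0, 7)), (1, (10, 12))];
        assert_eq!(select_matches(matches), vec![(0, (0, 7)), (1, (10, 12))]);
    }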


@@ -1,4 +1,4 @@
-use super::{AddedToken, Tokenizer};
+use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
 use crate::models::bpe::BPE;
 use serde::{
     self,
@@ -9,18 +9,6 @@ use serde::{
 static SERIALIZATION_VERSION: &str = "1.0";
 
-#[derive(Debug, Serialize, Deserialize)]
-struct AddedTokenWithId {
-    /// The id assigned to this token
-    id: u32,
-    /// Whether this is a special token
-    special: bool,
-    #[serde(flatten)]
-    /// The target AddedToken
-    token: AddedToken,
-}
-
 impl Serialize for Tokenizer {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -36,18 +24,7 @@ impl Serialize for Tokenizer {
tokenizer.serialize_field("padding", &self.padding)?;
// Added tokens
let mut added_tokens = self
.added_tokens_map_r
.iter()
.map(|(id, token)| AddedTokenWithId {
id: *id,
special: self.special_tokens_set.contains(&token.content),
token: token.clone(),
})
.collect::<Vec<_>>();
// We need to have these added tokens ordered by ascending ID
added_tokens.sort_unstable_by_key(|o| o.id);
tokenizer.serialize_field("added_tokens", &added_tokens)?;
tokenizer.serialize_field("added_tokens", &self.added_vocabulary)?;
// Then add our parts
tokenizer.serialize_field("normalizer", &self.normalizer)?;
@@ -141,6 +118,8 @@ impl<'de> Visitor<'de> for TokenizerVisitor {
             };
         }
 
+        // We take care of deserializing the added_tokens (instead of `AddedVocabulary` directly
+        // because it let us check that associated IDs are still good, and warn the user otherwise
         for token in tokens {
             let tk = token.token.content.clone();
             if token.special {