mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 04:08:22 +00:00
Rust - Extract AddedVocabulary management from Tokenizer
This commit is contained in:
@@ -433,8 +433,8 @@ impl Model for BPE {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
self.vocab_r.get(&id).cloned()
|
||||
fn id_to_token(&self, id: u32) -> Option<&str> {
|
||||
self.vocab_r.get(&id).map(String::as_ref)
|
||||
}
|
||||
|
||||
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
|
||||
|
||||
@@ -169,8 +169,8 @@ impl Model for WordLevel {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
self.vocab_r.get(&id).cloned()
|
||||
fn id_to_token(&self, id: u32) -> Option<&str> {
|
||||
self.vocab_r.get(&id).map(String::as_ref)
|
||||
}
|
||||
|
||||
fn get_vocab(&self) -> &HashMap<String, u32> {
|
||||
|
||||
@@ -283,8 +283,8 @@ impl Model for WordPiece {
|
||||
self.vocab.get(token).copied()
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
self.vocab_r.get(&id).cloned()
|
||||
fn id_to_token(&self, id: u32) -> Option<&str> {
|
||||
self.vocab_r.get(&id).map(String::as_ref)
|
||||
}
|
||||
|
||||
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
|
||||
|
||||
@@ -15,19 +15,20 @@ pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams,
|
||||
pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::prelude::*,
|
||||
io::BufReader,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
mod added_vocabulary;
|
||||
mod encoding;
|
||||
mod normalizer;
|
||||
mod serialization;
|
||||
|
||||
pub use added_vocabulary::*;
|
||||
pub use encoding::*;
|
||||
pub use normalizer::*;
|
||||
|
||||
@@ -56,7 +57,7 @@ pub trait PreTokenizer: Send + Sync {
|
||||
pub trait Model: Send + Sync {
|
||||
fn tokenize(&self, tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>>;
|
||||
fn token_to_id(&self, token: &str) -> Option<u32>;
|
||||
fn id_to_token(&self, id: u32) -> Option<String>;
|
||||
fn id_to_token(&self, id: u32) -> Option<&str>;
|
||||
fn get_vocab(&self) -> &HashMap<String, u32>;
|
||||
fn get_vocab_size(&self) -> usize;
|
||||
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
|
||||
@@ -182,102 +183,6 @@ impl<I1: Into<InputSequence>, I2: Into<InputSequence>> From<(I1, I2)> for Encode
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AddedToken {
|
||||
/// The content of the added token
|
||||
pub content: String,
|
||||
/// Whether this token must be a single word or can break words
|
||||
pub single_word: bool,
|
||||
/// Whether this token should strip whitespaces on its left
|
||||
pub lstrip: bool,
|
||||
/// Whether this token should strip whitespaces on its right
|
||||
pub rstrip: bool,
|
||||
}
|
||||
impl AddedToken {
|
||||
pub fn from(content: String) -> Self {
|
||||
AddedToken {
|
||||
content,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
pub fn single_word(mut self, single_word: bool) -> Self {
|
||||
self.single_word = single_word;
|
||||
self
|
||||
}
|
||||
pub fn lstrip(mut self, lstrip: bool) -> Self {
|
||||
self.lstrip = lstrip;
|
||||
self
|
||||
}
|
||||
pub fn rstrip(mut self, rstrip: bool) -> Self {
|
||||
self.rstrip = rstrip;
|
||||
self
|
||||
}
|
||||
pub fn get_pattern(&self) -> String {
|
||||
let mut r = if self.single_word {
|
||||
let first_b = self
|
||||
.content
|
||||
.chars()
|
||||
.next()
|
||||
.map(|c| {
|
||||
if regex_syntax::is_word_character(c) {
|
||||
r"\b"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let last_b = self
|
||||
.content
|
||||
.chars()
|
||||
.last()
|
||||
.map(|c| {
|
||||
if regex_syntax::is_word_character(c) {
|
||||
r"\b"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
format!(r"{}{}{}", first_b, regex::escape(&self.content), last_b)
|
||||
} else {
|
||||
regex::escape(&self.content)
|
||||
};
|
||||
|
||||
if self.lstrip && self.rstrip {
|
||||
r = format!(r"(\s)?{}(\s)?", r);
|
||||
} else if self.lstrip {
|
||||
r = format!(r"(\s)?{}", r);
|
||||
} else if self.rstrip {
|
||||
r = format!(r"{}(\s)?", r);
|
||||
}
|
||||
|
||||
r
|
||||
}
|
||||
}
|
||||
impl Default for AddedToken {
|
||||
fn default() -> Self {
|
||||
AddedToken {
|
||||
content: String::new(),
|
||||
single_word: false,
|
||||
lstrip: false,
|
||||
rstrip: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
// We only want to hash on the content. AddedToken cannot be added multiple times with different
|
||||
// options
|
||||
impl std::hash::Hash for AddedToken {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.content.hash(state);
|
||||
}
|
||||
}
|
||||
impl std::cmp::PartialEq for AddedToken {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.content == other.content
|
||||
}
|
||||
}
|
||||
impl std::cmp::Eq for AddedToken {}
|
||||
|
||||
/// A `Tokenizer` is capable of encoding/decoding any text.
|
||||
pub struct Tokenizer {
|
||||
// Tokenizer parts
|
||||
@@ -288,21 +193,7 @@ pub struct Tokenizer {
|
||||
decoder: Option<Box<dyn Decoder>>,
|
||||
|
||||
// Added Vocabulary capabilities
|
||||
/// Contains the mapping from String to ID as the user intended it. This map
|
||||
/// contains both special tokens and classic added tokens.
|
||||
added_tokens_map: HashMap<String, u32>,
|
||||
/// Contains the mapping from ID to AddedToken for all the added tokens, both special
|
||||
/// and classic.
|
||||
added_tokens_map_r: HashMap<u32, AddedToken>,
|
||||
/// Contains only the classic AddedToken, in the specific order the user gave them.
|
||||
added_tokens: Vec<AddedToken>,
|
||||
/// Contains only the special AddedToken, in the specific order the user gave them.
|
||||
special_tokens: Vec<AddedToken>,
|
||||
/// A Set, containing all the special token for easy access while decoding. This let's
|
||||
/// use remove them easily with an O(1) complexity.
|
||||
special_tokens_set: HashSet<String>,
|
||||
/// A RegexSet containing all the patterns used to split on AddedTokens
|
||||
split_re: regex::RegexSet,
|
||||
added_vocabulary: AddedVocabulary,
|
||||
|
||||
// General processing parameters
|
||||
truncation: Option<TruncationParams>,
|
||||
@@ -327,12 +218,7 @@ impl Tokenizer {
|
||||
post_processor: None,
|
||||
decoder: None,
|
||||
|
||||
added_tokens_map: HashMap::new(),
|
||||
added_tokens_map_r: HashMap::new(),
|
||||
added_tokens: vec![],
|
||||
special_tokens: vec![],
|
||||
special_tokens_set: HashSet::new(),
|
||||
split_re: regex::RegexSet::new::<_, &&str>(&[]).unwrap(),
|
||||
added_vocabulary: AddedVocabulary::new(),
|
||||
|
||||
truncation: None,
|
||||
padding: None,
|
||||
@@ -461,12 +347,15 @@ impl Tokenizer {
|
||||
pub fn get_vocab(&self, with_added_tokens: bool) -> HashMap<String, u32> {
|
||||
let mut final_vocab = self.model.get_vocab().clone();
|
||||
|
||||
if with_added_tokens && !self.added_tokens_map.is_empty() {
|
||||
final_vocab.reserve(self.added_tokens_map.len());
|
||||
for (token, id) in &self.added_tokens_map {
|
||||
if with_added_tokens {
|
||||
let added_vocab = self.added_vocabulary.get_vocab();
|
||||
if !added_vocab.is_empty() {
|
||||
final_vocab.reserve(added_vocab.len());
|
||||
for (token, id) in added_vocab {
|
||||
final_vocab.insert(token.clone(), *id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
final_vocab
|
||||
}
|
||||
@@ -475,7 +364,7 @@ impl Tokenizer {
|
||||
pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
|
||||
self.model.get_vocab_size()
|
||||
+ if with_added_tokens {
|
||||
self.added_tokens_map.len()
|
||||
self.added_vocabulary.len()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
@@ -483,26 +372,24 @@ impl Tokenizer {
|
||||
|
||||
/// Converts a token in the corresponding id.
|
||||
pub fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
if let Some(id) = self.added_tokens_map.get(token) {
|
||||
Some(*id)
|
||||
} else {
|
||||
self.model.token_to_id(token)
|
||||
}
|
||||
self.added_vocabulary
|
||||
.token_to_id(token)
|
||||
.copied()
|
||||
.or_else(|| self.model.token_to_id(token))
|
||||
}
|
||||
|
||||
/// Converts an id to the corresponding token.
|
||||
pub fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
if let Some(token) = self.added_tokens_map_r.get(&id) {
|
||||
Some(token.content.clone())
|
||||
} else {
|
||||
self.model.id_to_token(id)
|
||||
}
|
||||
pub fn id_to_token(&self, id: u32) -> Option<&str> {
|
||||
self.added_vocabulary
|
||||
.id_to_token(id)
|
||||
.or_else(|| self.model.id_to_token(id))
|
||||
}
|
||||
|
||||
/// Normalize the given sentence and return the corresponding normalized string
|
||||
pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
|
||||
let mut normalized = self
|
||||
.split_on_added_tokens(sentence)
|
||||
.added_vocabulary
|
||||
.extract(sentence)
|
||||
.into_iter()
|
||||
.map(|(sentence, id)| -> Result<NormalizedString> {
|
||||
if id.is_some() {
|
||||
@@ -534,7 +421,7 @@ impl Tokenizer {
|
||||
|
||||
let mut sequence_encodings = vec![];
|
||||
for subseq in sequence {
|
||||
let results = self.split_on_added_tokens(&subseq).into_iter().map(
|
||||
let results = self.added_vocabulary.extract(&subseq).into_iter().map(
|
||||
|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
|
||||
if let Some(id) = id {
|
||||
Ok((
|
||||
@@ -666,15 +553,17 @@ impl Tokenizer {
|
||||
let tokens = ids
|
||||
.into_iter()
|
||||
.map(|id| {
|
||||
let token = if let Some(token) = self.added_tokens_map_r.get(&id) {
|
||||
Some(token.content.to_owned())
|
||||
let token = if let Some(token) = self.added_vocabulary.id_to_token(id) {
|
||||
Some(token)
|
||||
} else {
|
||||
self.model.id_to_token(id)
|
||||
};
|
||||
|
||||
token.filter(|token| {
|
||||
!skip_special_tokens || !self.special_tokens_set.contains(token)
|
||||
token
|
||||
.filter(|token| {
|
||||
!skip_special_tokens || !self.added_vocabulary.is_special_token(token)
|
||||
})
|
||||
.map(|t| t.to_owned())
|
||||
})
|
||||
.filter(|token| token.is_some())
|
||||
.map(|id| id.unwrap())
|
||||
@@ -863,174 +752,13 @@ impl Tokenizer {
|
||||
/// Register the given tokens as special tokens. This is especially useful for removing
|
||||
/// these special tokens while decoding
|
||||
pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
|
||||
for token in tokens {
|
||||
if !self.special_tokens_set.contains(&token.content) {
|
||||
self.special_tokens.push(token.to_owned());
|
||||
self.special_tokens_set.insert(token.content.clone());
|
||||
}
|
||||
}
|
||||
let added = self.add_tokens(&tokens);
|
||||
|
||||
self.refresh_added_tokens();
|
||||
|
||||
added
|
||||
self.added_vocabulary
|
||||
.add_special_tokens(tokens, self.model.as_ref())
|
||||
}
|
||||
|
||||
/// Add the given tokens to the added vocabulary
|
||||
pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
|
||||
let mut ignored = 0;
|
||||
for token in tokens {
|
||||
if token.content.is_empty() {
|
||||
ignored += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let id = if let Some(id) = self.token_to_id(&token.content) {
|
||||
ignored += 1;
|
||||
id
|
||||
} else {
|
||||
let new_id = (self.model.get_vocab_size() + self.added_tokens_map.len()) as u32;
|
||||
self.added_tokens_map.insert(token.content.clone(), new_id);
|
||||
|
||||
if !self.special_tokens_set.contains(&token.content) {
|
||||
self.added_tokens.push(token.clone());
|
||||
}
|
||||
|
||||
new_id
|
||||
};
|
||||
|
||||
// Update the current revert operation
|
||||
self.added_tokens_map_r
|
||||
.entry(id)
|
||||
.and_modify(|t| *t = token.clone())
|
||||
.or_insert_with(|| token.clone());
|
||||
}
|
||||
|
||||
self.refresh_added_tokens();
|
||||
|
||||
// Return the number of added tokens
|
||||
tokens.len() - ignored
|
||||
}
|
||||
|
||||
fn refresh_added_tokens(&mut self) {
|
||||
self.split_re = regex::RegexSet::new(
|
||||
self.special_tokens
|
||||
.iter()
|
||||
.chain(self.added_tokens.iter())
|
||||
.map(|token| token.get_pattern()),
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
/// Split the given sentence on multiple parts, finding the added tokens and their id in
|
||||
/// the process
|
||||
fn split_on_added_tokens(&self, sentence: &str) -> Vec<(NormalizedString, Option<u32>)> {
|
||||
let mut matches = self
|
||||
.split_re
|
||||
.matches(sentence)
|
||||
.into_iter()
|
||||
.flat_map(|idx| {
|
||||
regex::Regex::new(&self.split_re.patterns()[idx])
|
||||
.unwrap()
|
||||
.find_iter(&sentence)
|
||||
.map(|m| (idx, (m.start(), m.end())))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// We sort all the matches by their start and then by their pattern id
|
||||
matches.sort_by(
|
||||
|(idxa, (sa, _)), (idxb, (sb, _))| {
|
||||
if sa != sb {
|
||||
sa.cmp(sb)
|
||||
} else {
|
||||
idxa.cmp(idxb)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// Select the matches (if some are overlapping) we want to keep
|
||||
let mut i = 0;
|
||||
let mut current_offset = 0;
|
||||
let mut splits = Vec::with_capacity(matches.len());
|
||||
while i < matches.len() {
|
||||
let (idx, (start, end)) = matches[i];
|
||||
|
||||
// current match is before the currentt offset, let's skip it
|
||||
if start < current_offset {
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find out if we have overlapping neighbors. If so, we keep the one with the lowest
|
||||
// idx, and apply it, then continue. All others will be skipped since `current_offset`
|
||||
// will have been increased
|
||||
if i + 1 < matches.len() {
|
||||
if let Some((idx, (s, e))) = matches[i..]
|
||||
.iter()
|
||||
.take_while(|(_, (s, e))| *s < end && start < *e)
|
||||
.min() // Order on idx first
|
||||
.copied()
|
||||
{
|
||||
splits.push((idx, (s, e)));
|
||||
current_offset = e;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// We didn't find overlapping neighbors, apply ourself
|
||||
splits.push((idx, (start, end)));
|
||||
current_offset = end;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// We also insert the splits that are inbetween the added tokens, to split the entire string
|
||||
let mut start_offset = 0;
|
||||
let mut splits = splits
|
||||
.into_iter()
|
||||
.flat_map(|(idx, (start, end))| {
|
||||
let mut splits = vec![];
|
||||
if start_offset < start {
|
||||
splits.push((None, (start_offset, start)));
|
||||
}
|
||||
splits.push((Some(idx), (start, end)));
|
||||
start_offset = end;
|
||||
|
||||
splits
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if let Some((_, (_, end))) = splits.iter().last().copied() {
|
||||
if end < sentence.len() {
|
||||
splits.push((None, (end, sentence.len())));
|
||||
}
|
||||
}
|
||||
|
||||
if splits.is_empty() {
|
||||
vec![(NormalizedString::from(sentence), None)]
|
||||
} else {
|
||||
splits
|
||||
.into_iter()
|
||||
.map(|(idx, (start, end))| unsafe {
|
||||
let s = sentence.get_unchecked(start..end).to_owned();
|
||||
let normalized = NormalizedString::from(&s);
|
||||
|
||||
// Find out the associated AddedToken, and its id
|
||||
let id = if let Some(idx) = idx {
|
||||
let added = if idx >= self.special_tokens.len() {
|
||||
&self.added_tokens[idx - self.special_tokens.len()]
|
||||
} else {
|
||||
&self.special_tokens[idx]
|
||||
};
|
||||
|
||||
self.token_to_id(&added.content)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
(normalized, id)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
self.added_vocabulary
|
||||
.add_tokens(tokens, self.model.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::{AddedToken, Tokenizer};
|
||||
use super::{added_vocabulary::AddedTokenWithId, Tokenizer};
|
||||
use crate::models::bpe::BPE;
|
||||
use serde::{
|
||||
self,
|
||||
@@ -9,18 +9,6 @@ use serde::{
|
||||
|
||||
static SERIALIZATION_VERSION: &str = "1.0";
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct AddedTokenWithId {
|
||||
/// The id assigned to this token
|
||||
id: u32,
|
||||
/// Whether this is a special token
|
||||
special: bool,
|
||||
|
||||
#[serde(flatten)]
|
||||
/// The target AddedToken
|
||||
token: AddedToken,
|
||||
}
|
||||
|
||||
impl Serialize for Tokenizer {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
@@ -36,18 +24,7 @@ impl Serialize for Tokenizer {
|
||||
tokenizer.serialize_field("padding", &self.padding)?;
|
||||
|
||||
// Added tokens
|
||||
let mut added_tokens = self
|
||||
.added_tokens_map_r
|
||||
.iter()
|
||||
.map(|(id, token)| AddedTokenWithId {
|
||||
id: *id,
|
||||
special: self.special_tokens_set.contains(&token.content),
|
||||
token: token.clone(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
// We need to have these added tokens ordered by ascending ID
|
||||
added_tokens.sort_unstable_by_key(|o| o.id);
|
||||
tokenizer.serialize_field("added_tokens", &added_tokens)?;
|
||||
tokenizer.serialize_field("added_tokens", &self.added_vocabulary)?;
|
||||
|
||||
// Then add our parts
|
||||
tokenizer.serialize_field("normalizer", &self.normalizer)?;
|
||||
@@ -141,6 +118,8 @@ impl<'de> Visitor<'de> for TokenizerVisitor {
|
||||
};
|
||||
}
|
||||
|
||||
// We take care of deserializing the added_tokens (instead of `AddedVocabulary` directly
|
||||
// because it let us check that associated IDs are still good, and warn the user otherwise
|
||||
for token in tokens {
|
||||
let tk = token.token.content.clone();
|
||||
if token.special {
|
||||
|
||||
Reference in New Issue
Block a user