Rust - Add AddedVocabulary + normalized option on AddedToken
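The change in one line: `tk::tokenizer::AddedToken::from` now takes a second boolean marking the token as special, special tokens are no longer normalized by default, and a new `normalized` builder (plus the matching Python kwarg and getter) overrides that default. A minimal, hypothetical usage sketch of the new constructor and builders, based only on what the diff below shows:

```rust
// Illustrative sketch only (not part of the commit); `tk` is the crate alias
// used by the Python bindings in this diff.
use tk::tokenizer::AddedToken;

fn sketch() {
    // Regular added token: special_token = false, so it stays normalized.
    let ing = AddedToken::from(String::from("ing"), false).single_word(false);
    assert!(ing.normalized);

    // Special token: special_token = true, so it is not normalized by default,
    // unless explicitly re-enabled with the new `normalized` builder.
    let mask = AddedToken::from(String::from("[MASK]"), true).lstrip(true);
    assert!(!mask.normalized);
    let mask_normalized = AddedToken::from(String::from("[MASK]"), true).normalized(true);
    assert!(mask_normalized.normalized);
}
```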
@@ -28,8 +28,8 @@ pub struct AddedToken {
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
+    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = tk::tokenizer::AddedToken::from(content.to_owned(), is_special_token);
 
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
@@ -38,6 +38,7 @@ impl AddedToken {
                     "single_word" => token = token.single_word(value.extract()?),
                     "lstrip" => token = token.lstrip(value.extract()?),
                     "rstrip" => token = token.rstrip(value.extract()?),
+                    "normalized" => token = token.normalized(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
@@ -65,6 +66,11 @@ impl AddedToken {
     fn get_single_word(&self) -> bool {
         self.token.single_word
     }
+
+    #[getter]
+    fn get_normalized(&self) -> bool {
+        self.token.normalized
+    }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
@@ -533,7 +539,7 @@ impl Tokenizer {
         self.tokenizer.token_to_id(token)
     }
 
-    fn id_to_token(&self, id: u32) -> Option<String> {
+    fn id_to_token(&self, id: u32) -> Option<&str> {
         self.tokenizer.id_to_token(id)
     }
 
@@ -542,10 +548,7 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        ..Default::default()
-                    })
+                    Ok(tk::tokenizer::AddedToken::from(content, false))
                 } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
                     Ok(token.token.clone())
                 } else {
@@ -564,10 +567,7 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        ..Default::default()
-                    })
+                    Ok(tk::tokenizer::AddedToken::from(content, true))
                 } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
                     Ok(token.token.clone())
                 } else {
@@ -17,10 +17,9 @@ fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(Box::new(bpe));
     tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
-    tokenizer.add_tokens(&[
-        AddedToken::from(String::from("ing")).single_word(false),
-        AddedToken::from(String::from("[ENT]")).single_word(true),
-    ]);
+    tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
+    tokenizer
+        .add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
     tokenizer
 }
 
@@ -21,10 +21,9 @@ fn shell(matches: &ArgMatches) -> Result<()> {
     tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
 
-    tokenizer.add_tokens(&[
-        AddedToken::from(String::from("ing")).single_word(false),
-        AddedToken::from(String::from("[ENT]")).single_word(true),
-    ]);
+    tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
+    tokenizer
+        .add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
 
     let stdin = io::stdin();
     let mut handle = stdin.lock();
tokenizers/src/tokenizer/added_vocabulary.rs (new file, 482 lines)
@@ -0,0 +1,482 @@
use super::{Model, NormalizedString, Normalizer, Range};
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use std::collections::{HashMap, HashSet};

/// Represent a token added by the user on top of the existing Model vocabulary.
/// AddedToken can be configured to specify the behavior they should have in various situations
/// like:
/// - Whether they should only match single words
/// - Whether to include any whitespace on its left or right
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddedToken {
    /// The content of the added token
    pub content: String,
    /// Whether this token must be a single word or can break words
    pub single_word: bool,
    /// Whether this token should strip whitespaces on its left
    pub lstrip: bool,
    /// Whether this token should strip whitespaces on its right
    pub rstrip: bool,
    /// Whether this token should be normalized
    pub normalized: bool,
}
impl AddedToken {
    /// Build this token from the given content, specifying if it is intented to be a
    /// special token. Special tokens are not normalized by default.
    pub fn from(content: String, special_token: bool) -> Self {
        AddedToken {
            content,
            normalized: !special_token,
            ..Default::default()
        }
    }
    /// Specify whether this token should only match on whole single words, and never
    /// part of a word.
    pub fn single_word(mut self, single_word: bool) -> Self {
        self.single_word = single_word;
        self
    }
    /// Specify whether this token should include all the whitespaces on its left, in
    /// order to strip them out.
    pub fn lstrip(mut self, lstrip: bool) -> Self {
        self.lstrip = lstrip;
        self
    }
    /// Specify whether this token should include all the whitespaces on its right, in
    /// order to strip them out.
    pub fn rstrip(mut self, rstrip: bool) -> Self {
        self.rstrip = rstrip;
        self
    }
    /// Specify whether this token should be normalized, and/or match against its normalized
    /// version in the input text.
    pub fn normalized(mut self, normalized: bool) -> Self {
        self.normalized = normalized;
        self
    }
    /// Retrive the pattern built for this token, according to all the specified parameters.
    pub fn get_pattern(&self, normalizer: Option<&dyn Normalizer>) -> String {
        let mut r = if self.single_word {
            let first_b = self
                .content
                .chars()
                .next()
                .map(|c| {
                    if regex_syntax::is_word_character(c) {
                        r"\b"
                    } else {
                        ""
                    }
                })
                .unwrap();
            let last_b = self
                .content
                .chars()
                .last()
                .map(|c| {
                    if regex_syntax::is_word_character(c) {
                        r"\b"
                    } else {
                        ""
                    }
                })
                .unwrap();

            // Normalize the content
            let mut content = NormalizedString::from(&self.content);
            normalizer.map(|n| n.normalize(&mut content));
            format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
        } else {
            regex::escape(&self.content)
        };

        if self.lstrip && self.rstrip {
            r = format!(r"(\s)?{}(\s)?", r);
        } else if self.lstrip {
            r = format!(r"(\s)?{}", r);
        } else if self.rstrip {
            r = format!(r"{}(\s)?", r);
        }

        r
    }
}
impl Default for AddedToken {
    fn default() -> Self {
        AddedToken {
            content: String::new(),
            single_word: false,
            lstrip: false,
            rstrip: false,
            normalized: false,
        }
    }
}
// We only want to hash on the content. AddedToken cannot be added multiple times with different
// options
impl std::hash::Hash for AddedToken {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.content.hash(state);
    }
}
impl std::cmp::PartialEq for AddedToken {
    fn eq(&self, other: &Self) -> bool {
        self.content == other.content
    }
}
impl std::cmp::Eq for AddedToken {}

type MatchingSet = (regex::RegexSet, Vec<u32>);

///
/// A vocabulary built on top of the Model
///
/// This provides a way to add new vocabulary to a Tokenizer that has already been trained,
/// in a previous process, maybe by someone else. This is especially interesting in the case
/// of fine-tunings, where we want to finetune a model while adding some new functionalities
/// using some new special tokens, or maybe add some tokens in the case of unknown tokens, etc.
///
/// One of the reasons we need to handle these tokens outside of the model is simply that
/// for many models, it is not possible to add new tokens after the training process. For example,
/// using BPE, the training process generates merges pairs along the vocabulary, and any token
/// in the vocabulary can be decomposed in other tokens, down to the original alphabet. If we
/// were to add new tokens after this training process, we couldn't make sure the merges pairs
/// exist as required.
///
pub(super) struct AddedVocabulary {
    /// The size of the original vocabulary. This is what we use to determine the new
    /// ids we need to generate
    original_vocab_size: usize,
    /// Contains the mapping from String to ID as the user intended it. This map
    /// contains both special tokens and classic added tokens.
    added_tokens_map: HashMap<String, u32>,
    /// Contains the mapping from ID to AddedToken for all the added tokens, both special
    /// and classic.
    added_tokens_map_r: HashMap<u32, AddedToken>,
    /// Contains only the classic AddedToken, in the specific order the user gave them.
    added_tokens: Vec<AddedToken>,
    /// Contains only the special AddedToken, in the specific order the user gave them.
    special_tokens: Vec<AddedToken>,
    /// A Set, containing all the special token for easy access while decoding. This let's
    /// use remove them easily with an O(1) complexity.
    special_tokens_set: HashSet<String>,
    /// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
    split_re: MatchingSet,
    /// A RegexSet containing all the normalized patterns used to split on AddedTokens
    split_normalized_re: MatchingSet,
}

impl AddedVocabulary {
    pub fn new(original_vocab_size: usize) -> Self {
        Self {
            original_vocab_size,
            added_tokens_map: HashMap::new(),
            added_tokens_map_r: HashMap::new(),
            added_tokens: vec![],
            special_tokens: vec![],
            special_tokens_set: HashSet::new(),
            split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
            split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
        }
    }

    /// Sets the original vocabulary size. We need this value to return IDs that
    /// are shifted according to the original vocabulary.
    pub fn update_original_vocab_size(&mut self, size: usize) {
        self.original_vocab_size = size;
    }

    /// Size of the additional vocabulary
    pub fn len(&self) -> usize {
        self.added_tokens_map.len()
    }

    /// Get the additional vocabulary
    pub fn get_vocab(&self) -> &HashMap<String, u32> {
        &self.added_tokens_map
    }

    /// Get the id matching one of our token if it exists
    pub fn token_to_id(&self, token: &str) -> Option<&u32> {
        self.added_tokens_map.get(token)
    }

    /// Get the token matching the given id if it exists
    pub fn id_to_token(&self, id: u32) -> Option<&str> {
        self.added_tokens_map_r.get(&id).map(|t| t.content.as_ref())
    }

    /// Check if a token is a special token
    pub fn is_special_token(&self, token: &str) -> bool {
        self.special_tokens_set.contains(token)
    }

    /// Add some special tokens to the vocabulary
    pub fn add_special_tokens(
        &mut self,
        tokens: &[AddedToken],
        model: &dyn Model,
        normalizer: Option<&dyn Normalizer>,
    ) -> usize {
        for token in tokens {
            if !self.special_tokens_set.contains(&token.content) {
                self.special_tokens.push(token.to_owned());
                self.special_tokens_set.insert(token.content.clone());
            }
        }
        let added = self.add_tokens(&tokens, model, normalizer);

        self.refresh_added_tokens(normalizer);

        added
    }

    /// Add some tokens to the vocabulary
    pub fn add_tokens(
        &mut self,
        tokens: &[AddedToken],
        model: &dyn Model,
        normalizer: Option<&dyn Normalizer>,
    ) -> usize {
        let mut ignored = 0;
        for token in tokens {
            if token.content.is_empty() {
                ignored += 1;
                continue;
            }

            let id = if let Some(id) = model.token_to_id(&token.content) {
                ignored += 1;
                id
            } else {
                let new_id = (self.original_vocab_size + self.added_tokens_map.len()) as u32;
                self.added_tokens_map.insert(token.content.clone(), new_id);

                if !self.special_tokens_set.contains(&token.content) {
                    self.added_tokens.push(token.clone());
                }

                new_id
            };

            // Update the current revert operation
            self.added_tokens_map_r
                .entry(id)
                .and_modify(|t| *t = token.clone())
                .or_insert_with(|| token.clone());
        }

        self.refresh_added_tokens(normalizer);

        // Return the number of added tokens
        tokens.len() - ignored
    }

    /// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
    ///
    /// We keep two different RegexSet, one that will take care of matching against the
    /// non-normalized string, and one matching against the normalized one.
    fn refresh_added_tokens(&mut self, normalizer: Option<&dyn Normalizer>) {
        type TupleTokenId<'a> = (&'a AddedToken, u32);
        let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
            .special_tokens
            .iter()
            .chain(self.added_tokens.iter())
            // TODO: Fix this: special tokens that are part of the original vocabulary are
            // not part of the `self.added_tokens_map` and so it crashes.
            .map(|token| (token, self.added_tokens_map[&token.content]))
            .partition(|(token, _)| token.normalized);

        let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
        self.split_re = (
            regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
            ids,
        );

        let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
        self.split_normalized_re = (
            regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
            ids,
        );
    }

    /// TODO: Add doc string here
    fn extract(
        &self,
        sentence: NormalizedString,
        split_re: &MatchingSet,
    ) -> Vec<(NormalizedString, Option<u32>)> {
        let mut matches = split_re
            .0
            .matches(sentence.get())
            .into_iter()
            .flat_map(|idx| {
                regex::Regex::new(&split_re.0.patterns()[idx])
                    .unwrap()
                    .find_iter(sentence.get())
                    .map(|m| (idx, (m.start(), m.end())))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        // We sort all the matches by their start and then by their pattern id
        matches.sort_by(
            |(idxa, (sa, _)), (idxb, (sb, _))| {
                if sa != sb {
                    sa.cmp(sb)
                } else {
                    idxa.cmp(idxb)
                }
            },
        );

        // Select the matches (if some are overlapping) we want to keep
        let mut i = 0;
        let mut current_offset = 0;
        let mut splits = Vec::with_capacity(matches.len());
        while i < matches.len() {
            let (idx, (start, end)) = matches[i];

            // current match is before the currentt offset, let's skip it
            if start < current_offset {
                i += 1;
                continue;
            }

            // Find out if we have overlapping neighbors. If so, we keep the one with the lowest
            // idx, and apply it, then continue. All others will be skipped since `current_offset`
            // will have been increased
            if i + 1 < matches.len() {
                if let Some((idx, (s, e))) = matches[i..]
                    .iter()
                    .take_while(|(_, (s, e))| *s < end && start < *e)
                    .min() // Order on idx first
                    .copied()
                {
                    splits.push((idx, (s, e)));
                    current_offset = e;
                    i += 1;
                    continue;
                }
            }

            // We didn't find overlapping neighbors, apply ourself
            splits.push((idx, (start, end)));
            current_offset = end;
            i += 1;
        }

        // We also insert the splits that are inbetween the added tokens, to split the entire string
        let mut start_offset = 0;
        let mut splits = splits
            .into_iter()
            .flat_map(|(idx, (start, end))| {
                let mut splits = vec![];
                if start_offset < start {
                    splits.push((None, (start_offset, start)));
                }
                splits.push((Some(idx), (start, end)));
                start_offset = end;

                splits
            })
            .collect::<Vec<_>>();
        if let Some((_, (_, end))) = splits.iter().last().copied() {
            if end < sentence.get().len() {
                splits.push((None, (end, sentence.get().len())));
            }
        }

        if splits.is_empty() {
            vec![(sentence, None)]
        } else {
            splits
                .into_iter()
                .map(|(idx, (start, end))| {
                    // TODO: Check this works
                    let normalized = sentence
                        .slice_bytes(Range::Normalized(start..end))
                        .expect("Error while extracting normalized Range");

                    // Find out the associated AddedToken, and its id
                    let id = idx.map(|idx| split_re.1[idx]);

                    (normalized, id)
                })
                .collect()
        }
    }

    /// Extract the additional vocabulary from the given sentence, normalizing it along the way.
    ///
    /// Some tokens should match against their normalized representation, as well as the
    /// non-normalized one. For example, when we expect to extract the token `yesterday` in the
    /// input sentence `I read a book Yesterday`, if the normalizer is supposed to lowercase
    /// everything, we expect a match.
    ///
    /// This method returns a `Vec` of `(NormalizedString, Option<u32>)`, where the optional `u32`
    /// contains the relevant ID if this is an additional token.
    pub fn extract_and_normalize(
        &self,
        normalizer: Option<&dyn Normalizer>,
        sentence: &str,
    ) -> Vec<(NormalizedString, Option<u32>)> {
        // 1. We extract all the non-normalized tokens from the non-normalized string
        let pieces = self.extract(NormalizedString::from(sentence), &self.split_re);

        // 2. Then extract the normalized tokens from the normalized pieces of the string
        pieces
            .into_iter()
            .flat_map(|(mut normalized, id)| {
                if id.is_some() {
                    // If the piece has an associated ID, we already extracted something,
                    // so we just return it
                    vec![(normalized, id)]
                } else {
                    // Otherwise, we need to normalized the string, and then proceed to extracting
                    normalizer.map(|n| n.normalize(&mut normalized));
                    self.extract(normalized, &self.split_normalized_re)
                }
            })
            .collect::<Vec<_>>()
    }
}

#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
    /// The id assigned to this token
    pub id: u32,
    /// Whether this is a special token
    pub special: bool,

    #[serde(flatten)]
    /// The target AddedToken
    pub token: AddedToken,
}

impl Serialize for AddedVocabulary {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut vocabulary = serializer.serialize_seq(Some(self.added_tokens_map.len()))?;

        let mut added_tokens = self
            .added_tokens_map_r
            .iter()
            .map(|(id, token)| AddedTokenWithId {
                id: *id,
                special: self.special_tokens_set.contains(&token.content),
                token: token.clone(),
            })
            .collect::<Vec<_>>();
        // We need to have these added tokens ordered by ascending ID
        added_tokens.sort_unstable_by_key(|o| o.id);

        for token in added_tokens {
            vocabulary.serialize_element(&token)?;
        }

        vocabulary.end()
    }
}
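For readers of `extract` above: each added token contributes one escaped pattern (via `get_pattern`), all patterns live in a single `regex::RegexSet`, and every pattern reported by the set is re-run to obtain concrete spans, which are then sorted by start offset and pattern index. Below is a standalone toy sketch of just that matching step; it is not code from this commit and only assumes the `regex` crate as a dependency:

```rust
use regex::{Regex, RegexSet};

fn main() {
    // One escaped pattern per added token, like `AddedToken::get_pattern`
    // does for non-single-word tokens.
    let tokens = ["[MASK]", "ing"];
    let patterns: Vec<String> = tokens.iter().map(|t| regex::escape(t)).collect();
    let set = RegexSet::new(&patterns).unwrap();

    let sentence = "I saw a [MASK] while dancing";
    let mut spans: Vec<(usize, (usize, usize))> = set
        .matches(sentence)
        .into_iter()
        .flat_map(|idx| {
            // Re-run the individual pattern to recover the concrete offsets.
            Regex::new(&patterns[idx])
                .unwrap()
                .find_iter(sentence)
                .map(|m| (idx, (m.start(), m.end())))
                .collect::<Vec<_>>()
        })
        .collect();

    // Sort by start offset, then by pattern id, as `extract` does.
    spans.sort_by(|(ia, (sa, _)), (ib, (sb, _))| sa.cmp(sb).then(ia.cmp(ib)));
    for (idx, (start, end)) in spans {
        println!("token {:?} matched at {}..{}", tokens[idx], start, end);
    }
}
```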
@@ -211,6 +211,7 @@ impl std::str::FromStr for Tokenizer {
 impl Tokenizer {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model>) -> Self {
+        let original_vocab_size = model.get_vocab_size();
         Tokenizer {
             normalizer: None,
             pre_tokenizer: None,
@@ -218,7 +219,7 @@ impl Tokenizer {
             post_processor: None,
             decoder: None,
 
-            added_vocabulary: AddedVocabulary::new(),
+            added_vocabulary: AddedVocabulary::new(original_vocab_size),
 
             truncation: None,
             padding: None,
@@ -302,6 +303,8 @@ impl Tokenizer {
     /// Set the model
     pub fn with_model(&mut self, model: Box<dyn Model>) -> &Self {
         self.model = model;
+        self.added_vocabulary
+            .update_original_vocab_size(self.model.get_vocab_size());
         self
     }
 
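The `original_vocab_size` kept up to date here is what shifts added-token ids past the model's vocabulary: in `AddedVocabulary::add_tokens`, a genuinely new token gets `original_vocab_size + added_tokens_map.len()` as its id. A toy illustration of that rule (hypothetical helper and values, not code from the commit):

```rust
use std::collections::HashMap;

// Mirrors the id computation in `AddedVocabulary::add_tokens`.
fn next_added_id(original_vocab_size: usize, added: &HashMap<String, u32>) -> u32 {
    (original_vocab_size + added.len()) as u32
}

fn main() {
    let original_vocab_size = 30_000; // illustrative model vocabulary size
    let mut added: HashMap<String, u32> = HashMap::new();

    for tok in ["[NEW1]", "[NEW2]"] {
        let id = next_added_id(original_vocab_size, &added);
        added.insert(tok.to_string(), id);
    }

    assert_eq!(added["[NEW1]"], 30_000);
    assert_eq!(added["[NEW2]"], 30_001);
}
```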
@@ -389,16 +392,14 @@ impl Tokenizer {
     pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
         let mut normalized = self
             .added_vocabulary
-            .extract(sentence)
+            .extract_and_normalize(self.normalizer.as_deref(), sentence)
             .into_iter()
-            .map(|(sentence, id)| -> Result<NormalizedString> {
+            .map(|(mut sentence, id)| -> Result<NormalizedString> {
                 if id.is_some() {
                     Ok(sentence)
                 } else {
-                    let mut normalized = self.do_normalize(sentence)?;
-                    let _ = self.pre_tokenize(&mut normalized)?;
-
-                    Ok(normalized)
+                    self.pre_tokenize(&mut sentence)?;
+                    Ok(sentence)
                 }
             })
             .collect::<Result<Vec<_>>>()?;
@@ -421,28 +422,30 @@ impl Tokenizer {
 
         let mut sequence_encodings = vec![];
         for subseq in sequence {
-            let results = self.added_vocabulary.extract(&subseq).into_iter().map(
-                |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
+            let results = self
+                .added_vocabulary
+                .extract_and_normalize(self.normalizer.as_deref(), &subseq)
+                .into_iter()
+                .map(
+                    |(mut normalized, id)| -> Result<(Encoding, NormalizedString)> {
                     if let Some(id) = id {
                         Ok((
                             Encoding::new(
                                 vec![id],
                                 vec![type_id],
-                                vec![sentence.get().to_owned()],
+                                vec![normalized.get().to_owned()],
                                 vec![Some(0)],
-                                vec![(0, sentence.len())],
+                                vec![(0, normalized.len())],
                                 vec![0],
                                 vec![1],
                                 vec![],
                             ),
-                            sentence,
+                            normalized,
                         ))
                     } else {
-                        // 1. Normalization
-                        let mut normalized = self.do_normalize(sentence)?;
-                        // 2. Pre tokenization
+                        // 1. Pre tokenization
                         let pre_tokenized = self.pre_tokenize(&mut normalized)?;
-                        // 3. Model
+                        // 2. Model
                         let tokens = self.model.tokenize(pre_tokenized)?;
                         let encoding = Encoding::from_tokens(tokens, type_id);
 
@@ -675,6 +678,8 @@ impl Tokenizer {
 
         let (model, special_tokens) = trainer.train(words)?;
         self.model = model;
+        self.added_vocabulary
+            .update_original_vocab_size(self.model.get_vocab_size());
         self.add_special_tokens(&special_tokens);
 
         Ok(())
@@ -752,13 +757,16 @@ impl Tokenizer {
     /// Register the given tokens as special tokens. This is especially useful for removing
     /// these special tokens while decoding
     pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        self.added_vocabulary
-            .add_special_tokens(tokens, self.model.as_ref())
+        self.added_vocabulary.add_special_tokens(
+            tokens,
+            self.model.as_ref(),
+            self.normalizer.as_deref(),
+        )
     }
 
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
         self.added_vocabulary
-            .add_tokens(tokens, self.model.as_ref())
+            .add_tokens(tokens, self.model.as_ref(), self.normalizer.as_deref())
     }
 }
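Caller-facing usage after this change, mirroring the gpt2 example hunk earlier in the diff (a sketch; the `tokenizers::tokenizer` import path and the surrounding setup are assumed):

```rust
use tokenizers::tokenizer::{AddedToken, Tokenizer};

// Assumes `tokenizer` was built elsewhere (model, normalizer, pre-tokenizer, ...).
fn add_vocab(tokenizer: &mut Tokenizer) {
    // Classic added token: goes through `add_tokens`.
    tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);

    // Special token: goes through `add_special_tokens`, so it is also tracked
    // for removal while decoding and is not normalized by default.
    tokenizer
        .add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
}
```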
@@ -9,8 +9,8 @@ fn add_tokens() {
 
     assert_eq!(
         tokenizer.add_special_tokens(&[
-            AddedToken::from("<cls>".into()),
-            AddedToken::from("<sep>".into())
+            AddedToken::from("<cls>".into(), true),
+            AddedToken::from("<sep>".into(), true)
         ]),
         2
     );
@@ -19,8 +19,8 @@ fn add_tokens() {
 
     assert_eq!(
         tokenizer.add_tokens(&[
-            AddedToken::from("hello".into()),
-            AddedToken::from("world".into())
+            AddedToken::from("hello".into(), false),
+            AddedToken::from("world".into(), false)
        ]),
         2
     );
@@ -31,7 +31,7 @@ fn add_tokens() {
 #[test]
 fn lstrip_tokens() {
     let mut tokenizer = get_byte_level(true, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).lstrip(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).lstrip(true)]);
 
     let input = "I saw a <mask> 😺";
     let output = tokenizer.encode(input, false).unwrap();
@@ -49,7 +49,7 @@ fn lstrip_tokens() {
 #[test]
 fn rstrip_tokens() {
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).rstrip(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
 
     let input = "I saw a <mask> 😺";
     let output = tokenizer.encode(input, false).unwrap();
@@ -62,7 +62,7 @@ fn rstrip_tokens() {
     // When `add_prefix_space = true` rstrip cannot work as a prefix space is added
     // to the next token
     let mut tokenizer = get_byte_level(true, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).rstrip(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
 
     let input = "I saw a <mask> 😺";
     let output = tokenizer.encode(input, false).unwrap();
@@ -77,7 +77,7 @@ fn rstrip_tokens() {
 fn single_word_tokens() {
     // If `single_word = true` it shouldn't split `dancing`
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into()).single_word(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(true)]);
 
     let input = "I like dancing";
     let output = tokenizer.encode(input, false).unwrap();
@@ -86,7 +86,7 @@ fn single_word_tokens() {
 
     // If `single_word = false` it should split `dancing`
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into()).single_word(false)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(false)]);
 
     let input = "I like dancing";
     let output = tokenizer.encode(input, false).unwrap();
@@ -98,9 +98,9 @@ fn single_word_tokens() {
 fn overlapping_tokens() {
     let mut tokenizer = get_byte_level(false, false);
 
-    tokenizer.add_special_tokens(&[AddedToken::from("danc".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("nci".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into())]);
+    tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
 
     let input = "I like dancing";
     let output = tokenizer.encode(input, false).unwrap();
@@ -109,10 +109,10 @@ fn overlapping_tokens() {
 
     let mut tokenizer = get_byte_level(false, false);
 
-    tokenizer.add_special_tokens(&[AddedToken::from("nci".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("danc".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("ike".into())]);
+    tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ike".into(), true)]);
 
     let output = tokenizer.encode(input, false).unwrap();
 
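The order dependence exercised by `overlapping_tokens` comes from the selection loop in `AddedVocabulary::extract`: among overlapping matches, the match with the lowest pattern index (the token registered first) is kept, and later overlapping matches are skipped. A self-contained toy version of that loop with hand-written spans for "danc", "nci" and "ing" in "I like dancing" (my sketch, not code from the commit):

```rust
fn main() {
    // (pattern_idx, (start, end)) spans in "I like dancing", with the tokens
    // registered in the order danc (0), nci (1), ing (2).
    let matches = vec![(0usize, (7usize, 11usize)), (1, (9, 12)), (2, (11, 14))];

    let mut sorted = matches;
    sorted.sort_by(|(ia, (sa, _)), (ib, (sb, _))| sa.cmp(sb).then(ia.cmp(ib)));

    let mut kept = Vec::new();
    let mut current_offset = 0;
    let mut i = 0;
    while i < sorted.len() {
        let (_, (start, end)) = sorted[i];
        // Skip anything that starts inside a span we already kept.
        if start < current_offset {
            i += 1;
            continue;
        }
        // Among the overlapping neighbours, keep the one with the lowest pattern index.
        if let Some(&(idx, (s, e))) = sorted[i..]
            .iter()
            .take_while(|(_, (s, e))| *s < end && start < *e)
            .min()
        {
            kept.push((idx, (s, e)));
            current_offset = e;
        }
        i += 1;
    }

    // With "danc" registered first, "dancing" splits as danc + ing.
    assert_eq!(kept, vec![(0, (7, 11)), (2, (11, 14))]);
}
```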
@@ -158,7 +158,7 @@ fn split_on_added_tokens_bert() {
     let input = "Yesterday I saw a [MASK] far away";
 
     let mut tokenizer = get_bert();
-    tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into())]);
+    tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into(), true)]);
     let output = tokenizer.encode(input, false).unwrap();
 
     assert_eq!(