mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-07 21:28:19 +00:00
Rust - Add AddedVocabulary + normalized option on AddedToken
This commit is contained in:
@@ -28,8 +28,8 @@ pub struct AddedToken {
|
||||
impl AddedToken {
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
|
||||
let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
|
||||
fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
|
||||
let mut token = tk::tokenizer::AddedToken::from(content.to_owned(), is_special_token);
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, value) in kwargs {
|
||||
@@ -38,6 +38,7 @@ impl AddedToken {
|
||||
"single_word" => token = token.single_word(value.extract()?),
|
||||
"lstrip" => token = token.lstrip(value.extract()?),
|
||||
"rstrip" => token = token.rstrip(value.extract()?),
|
||||
"normalized" => token = token.normalized(value.extract()?),
|
||||
_ => println!("Ignored unknown kwarg option {}", key),
|
||||
}
|
||||
}
|
||||
@@ -65,6 +66,11 @@ impl AddedToken {
|
||||
fn get_single_word(&self) -> bool {
|
||||
self.token.single_word
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_normalized(&self) -> bool {
|
||||
self.token.normalized
|
||||
}
|
||||
}
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for AddedToken {
|
||||
@@ -533,7 +539,7 @@ impl Tokenizer {
|
||||
self.tokenizer.token_to_id(token)
|
||||
}
|
||||
|
||||
fn id_to_token(&self, id: u32) -> Option<String> {
|
||||
fn id_to_token(&self, id: u32) -> Option<&str> {
|
||||
self.tokenizer.id_to_token(id)
|
||||
}
|
||||
|
||||
@@ -542,10 +548,7 @@ impl Tokenizer {
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken {
|
||||
content,
|
||||
..Default::default()
|
||||
})
|
||||
Ok(tk::tokenizer::AddedToken::from(content, false))
|
||||
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
|
||||
Ok(token.token.clone())
|
||||
} else {
|
||||
@@ -564,10 +567,7 @@ impl Tokenizer {
|
||||
.into_iter()
|
||||
.map(|token| {
|
||||
if let Ok(content) = token.extract::<String>() {
|
||||
Ok(tk::tokenizer::AddedToken {
|
||||
content,
|
||||
..Default::default()
|
||||
})
|
||||
Ok(tk::tokenizer::AddedToken::from(content, true))
|
||||
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
|
||||
Ok(token.token.clone())
|
||||
} else {
|
||||
|
||||
@@ -17,10 +17,9 @@ fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
|
||||
let mut tokenizer = Tokenizer::new(Box::new(bpe));
|
||||
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
|
||||
tokenizer.with_decoder(Box::new(ByteLevel::default()));
|
||||
tokenizer.add_tokens(&[
|
||||
AddedToken::from(String::from("ing")).single_word(false),
|
||||
AddedToken::from(String::from("[ENT]")).single_word(true),
|
||||
]);
|
||||
tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
|
||||
tokenizer
|
||||
.add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
|
||||
tokenizer
|
||||
}
|
||||
|
||||
|
||||
@@ -21,10 +21,9 @@ fn shell(matches: &ArgMatches) -> Result<()> {
|
||||
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
|
||||
tokenizer.with_decoder(Box::new(ByteLevel::default()));
|
||||
|
||||
tokenizer.add_tokens(&[
|
||||
AddedToken::from(String::from("ing")).single_word(false),
|
||||
AddedToken::from(String::from("[ENT]")).single_word(true),
|
||||
]);
|
||||
tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
|
||||
tokenizer
|
||||
.add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
|
||||
|
||||
let stdin = io::stdin();
|
||||
let mut handle = stdin.lock();
|
||||
|
||||
482
tokenizers/src/tokenizer/added_vocabulary.rs
Normal file
482
tokenizers/src/tokenizer/added_vocabulary.rs
Normal file
@@ -0,0 +1,482 @@
|
||||
use super::{Model, NormalizedString, Normalizer, Range};
|
||||
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Represent a token added by the user on top of the existing Model vocabulary.
|
||||
/// AddedToken can be configured to specify the behavior they should have in various situations
|
||||
/// like:
|
||||
/// - Whether they should only match single words
|
||||
/// - Whether to include any whitespace on its left or right
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AddedToken {
|
||||
/// The content of the added token
|
||||
pub content: String,
|
||||
/// Whether this token must be a single word or can break words
|
||||
pub single_word: bool,
|
||||
/// Whether this token should strip whitespaces on its left
|
||||
pub lstrip: bool,
|
||||
/// Whether this token should strip whitespaces on its right
|
||||
pub rstrip: bool,
|
||||
/// Whether this token should be normalized
|
||||
pub normalized: bool,
|
||||
}
|
||||
impl AddedToken {
|
||||
/// Build this token from the given content, specifying if it is intented to be a
|
||||
/// special token. Special tokens are not normalized by default.
|
||||
pub fn from(content: String, special_token: bool) -> Self {
|
||||
AddedToken {
|
||||
content,
|
||||
normalized: !special_token,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
/// Specify whether this token should only match on whole single words, and never
|
||||
/// part of a word.
|
||||
pub fn single_word(mut self, single_word: bool) -> Self {
|
||||
self.single_word = single_word;
|
||||
self
|
||||
}
|
||||
/// Specify whether this token should include all the whitespaces on its left, in
|
||||
/// order to strip them out.
|
||||
pub fn lstrip(mut self, lstrip: bool) -> Self {
|
||||
self.lstrip = lstrip;
|
||||
self
|
||||
}
|
||||
/// Specify whether this token should include all the whitespaces on its right, in
|
||||
/// order to strip them out.
|
||||
pub fn rstrip(mut self, rstrip: bool) -> Self {
|
||||
self.rstrip = rstrip;
|
||||
self
|
||||
}
|
||||
/// Specify whether this token should be normalized, and/or match against its normalized
|
||||
/// version in the input text.
|
||||
pub fn normalized(mut self, normalized: bool) -> Self {
|
||||
self.normalized = normalized;
|
||||
self
|
||||
}
|
||||
/// Retrive the pattern built for this token, according to all the specified parameters.
|
||||
pub fn get_pattern(&self, normalizer: Option<&dyn Normalizer>) -> String {
|
||||
let mut r = if self.single_word {
|
||||
let first_b = self
|
||||
.content
|
||||
.chars()
|
||||
.next()
|
||||
.map(|c| {
|
||||
if regex_syntax::is_word_character(c) {
|
||||
r"\b"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let last_b = self
|
||||
.content
|
||||
.chars()
|
||||
.last()
|
||||
.map(|c| {
|
||||
if regex_syntax::is_word_character(c) {
|
||||
r"\b"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Normalize the content
|
||||
let mut content = NormalizedString::from(&self.content);
|
||||
normalizer.map(|n| n.normalize(&mut content));
|
||||
format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
|
||||
} else {
|
||||
regex::escape(&self.content)
|
||||
};
|
||||
|
||||
if self.lstrip && self.rstrip {
|
||||
r = format!(r"(\s)?{}(\s)?", r);
|
||||
} else if self.lstrip {
|
||||
r = format!(r"(\s)?{}", r);
|
||||
} else if self.rstrip {
|
||||
r = format!(r"{}(\s)?", r);
|
||||
}
|
||||
|
||||
r
|
||||
}
|
||||
}
|
||||
impl Default for AddedToken {
|
||||
fn default() -> Self {
|
||||
AddedToken {
|
||||
content: String::new(),
|
||||
single_word: false,
|
||||
lstrip: false,
|
||||
rstrip: false,
|
||||
normalized: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
// We only want to hash on the content. AddedToken cannot be added multiple times with different
|
||||
// options
|
||||
impl std::hash::Hash for AddedToken {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.content.hash(state);
|
||||
}
|
||||
}
|
||||
impl std::cmp::PartialEq for AddedToken {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.content == other.content
|
||||
}
|
||||
}
|
||||
impl std::cmp::Eq for AddedToken {}
|
||||
|
||||
type MatchingSet = (regex::RegexSet, Vec<u32>);
|
||||
|
||||
///
|
||||
/// A vocabulary built on top of the Model
|
||||
///
|
||||
/// This provides a way to add new vocabulary to a Tokenizer that has already been trained,
|
||||
/// in a previous process, maybe by someone else. This is especially interesting in the case
|
||||
/// of fine-tunings, where we want to finetune a model while adding some new functionalities
|
||||
/// using some new special tokens, or maybe add some tokens in the case of unknown tokens, etc.
|
||||
///
|
||||
/// One of the reasons we need to handle these tokens outside of the model is simply that
|
||||
/// for many models, it is not possible to add new tokens after the training process. For example,
|
||||
/// using BPE, the training process generates merges pairs along the vocabulary, and any token
|
||||
/// in the vocabulary can be decomposed in other tokens, down to the original alphabet. If we
|
||||
/// were to add new tokens after this training process, we couldn't make sure the merges pairs
|
||||
/// exist as required.
|
||||
///
|
||||
pub(super) struct AddedVocabulary {
|
||||
/// The size of the original vocabulary. This is what we use to determine the new
|
||||
/// ids we need to generate
|
||||
original_vocab_size: usize,
|
||||
/// Contains the mapping from String to ID as the user intended it. This map
|
||||
/// contains both special tokens and classic added tokens.
|
||||
added_tokens_map: HashMap<String, u32>,
|
||||
/// Contains the mapping from ID to AddedToken for all the added tokens, both special
|
||||
/// and classic.
|
||||
added_tokens_map_r: HashMap<u32, AddedToken>,
|
||||
/// Contains only the classic AddedToken, in the specific order the user gave them.
|
||||
added_tokens: Vec<AddedToken>,
|
||||
/// Contains only the special AddedToken, in the specific order the user gave them.
|
||||
special_tokens: Vec<AddedToken>,
|
||||
/// A Set, containing all the special token for easy access while decoding. This let's
|
||||
/// use remove them easily with an O(1) complexity.
|
||||
special_tokens_set: HashSet<String>,
|
||||
/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
|
||||
split_re: MatchingSet,
|
||||
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
|
||||
split_normalized_re: MatchingSet,
|
||||
}
|
||||
|
||||
impl AddedVocabulary {
|
||||
pub fn new(original_vocab_size: usize) -> Self {
|
||||
Self {
|
||||
original_vocab_size,
|
||||
added_tokens_map: HashMap::new(),
|
||||
added_tokens_map_r: HashMap::new(),
|
||||
added_tokens: vec![],
|
||||
special_tokens: vec![],
|
||||
special_tokens_set: HashSet::new(),
|
||||
split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
|
||||
split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the original vocabulary size. We need this value to return IDs that
|
||||
/// are shifted according to the original vocabulary.
|
||||
pub fn update_original_vocab_size(&mut self, size: usize) {
|
||||
self.original_vocab_size = size;
|
||||
}
|
||||
|
||||
/// Size of the additional vocabulary
|
||||
pub fn len(&self) -> usize {
|
||||
self.added_tokens_map.len()
|
||||
}
|
||||
|
||||
/// Get the additional vocabulary
|
||||
pub fn get_vocab(&self) -> &HashMap<String, u32> {
|
||||
&self.added_tokens_map
|
||||
}
|
||||
|
||||
/// Get the id matching one of our token if it exists
|
||||
pub fn token_to_id(&self, token: &str) -> Option<&u32> {
|
||||
self.added_tokens_map.get(token)
|
||||
}
|
||||
|
||||
/// Get the token matching the given id if it exists
|
||||
pub fn id_to_token(&self, id: u32) -> Option<&str> {
|
||||
self.added_tokens_map_r.get(&id).map(|t| t.content.as_ref())
|
||||
}
|
||||
|
||||
/// Check if a token is a special token
|
||||
pub fn is_special_token(&self, token: &str) -> bool {
|
||||
self.special_tokens_set.contains(token)
|
||||
}
|
||||
|
||||
/// Add some special tokens to the vocabulary
|
||||
pub fn add_special_tokens(
|
||||
&mut self,
|
||||
tokens: &[AddedToken],
|
||||
model: &dyn Model,
|
||||
normalizer: Option<&dyn Normalizer>,
|
||||
) -> usize {
|
||||
for token in tokens {
|
||||
if !self.special_tokens_set.contains(&token.content) {
|
||||
self.special_tokens.push(token.to_owned());
|
||||
self.special_tokens_set.insert(token.content.clone());
|
||||
}
|
||||
}
|
||||
let added = self.add_tokens(&tokens, model, normalizer);
|
||||
|
||||
self.refresh_added_tokens(normalizer);
|
||||
|
||||
added
|
||||
}
|
||||
|
||||
/// Add some tokens to the vocabulary
|
||||
pub fn add_tokens(
|
||||
&mut self,
|
||||
tokens: &[AddedToken],
|
||||
model: &dyn Model,
|
||||
normalizer: Option<&dyn Normalizer>,
|
||||
) -> usize {
|
||||
let mut ignored = 0;
|
||||
for token in tokens {
|
||||
if token.content.is_empty() {
|
||||
ignored += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
let id = if let Some(id) = model.token_to_id(&token.content) {
|
||||
ignored += 1;
|
||||
id
|
||||
} else {
|
||||
let new_id = (self.original_vocab_size + self.added_tokens_map.len()) as u32;
|
||||
self.added_tokens_map.insert(token.content.clone(), new_id);
|
||||
|
||||
if !self.special_tokens_set.contains(&token.content) {
|
||||
self.added_tokens.push(token.clone());
|
||||
}
|
||||
|
||||
new_id
|
||||
};
|
||||
|
||||
// Update the current revert operation
|
||||
self.added_tokens_map_r
|
||||
.entry(id)
|
||||
.and_modify(|t| *t = token.clone())
|
||||
.or_insert_with(|| token.clone());
|
||||
}
|
||||
|
||||
self.refresh_added_tokens(normalizer);
|
||||
|
||||
// Return the number of added tokens
|
||||
tokens.len() - ignored
|
||||
}
|
||||
|
||||
/// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
|
||||
///
|
||||
/// We keep two different RegexSet, one that will take care of matching against the
|
||||
/// non-normalized string, and one matching against the normalized one.
|
||||
fn refresh_added_tokens(&mut self, normalizer: Option<&dyn Normalizer>) {
|
||||
type TupleTokenId<'a> = (&'a AddedToken, u32);
|
||||
let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
|
||||
.special_tokens
|
||||
.iter()
|
||||
.chain(self.added_tokens.iter())
|
||||
// TODO: Fix this: special tokens that are part of the original vocabulary are
|
||||
// not part of the `self.added_tokens_map` and so it crashes.
|
||||
.map(|token| (token, self.added_tokens_map[&token.content]))
|
||||
.partition(|(token, _)| token.normalized);
|
||||
|
||||
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
|
||||
self.split_re = (
|
||||
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
|
||||
ids,
|
||||
);
|
||||
|
||||
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
|
||||
self.split_normalized_re = (
|
||||
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
|
||||
ids,
|
||||
);
|
||||
}
|
||||
|
||||
/// TODO: Add doc string here
|
||||
fn extract(
|
||||
&self,
|
||||
sentence: NormalizedString,
|
||||
split_re: &MatchingSet,
|
||||
) -> Vec<(NormalizedString, Option<u32>)> {
|
||||
let mut matches = split_re
|
||||
.0
|
||||
.matches(sentence.get())
|
||||
.into_iter()
|
||||
.flat_map(|idx| {
|
||||
regex::Regex::new(&split_re.0.patterns()[idx])
|
||||
.unwrap()
|
||||
.find_iter(sentence.get())
|
||||
.map(|m| (idx, (m.start(), m.end())))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// We sort all the matches by their start and then by their pattern id
|
||||
matches.sort_by(
|
||||
|(idxa, (sa, _)), (idxb, (sb, _))| {
|
||||
if sa != sb {
|
||||
sa.cmp(sb)
|
||||
} else {
|
||||
idxa.cmp(idxb)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// Select the matches (if some are overlapping) we want to keep
|
||||
let mut i = 0;
|
||||
let mut current_offset = 0;
|
||||
let mut splits = Vec::with_capacity(matches.len());
|
||||
while i < matches.len() {
|
||||
let (idx, (start, end)) = matches[i];
|
||||
|
||||
// current match is before the currentt offset, let's skip it
|
||||
if start < current_offset {
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Find out if we have overlapping neighbors. If so, we keep the one with the lowest
|
||||
// idx, and apply it, then continue. All others will be skipped since `current_offset`
|
||||
// will have been increased
|
||||
if i + 1 < matches.len() {
|
||||
if let Some((idx, (s, e))) = matches[i..]
|
||||
.iter()
|
||||
.take_while(|(_, (s, e))| *s < end && start < *e)
|
||||
.min() // Order on idx first
|
||||
.copied()
|
||||
{
|
||||
splits.push((idx, (s, e)));
|
||||
current_offset = e;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// We didn't find overlapping neighbors, apply ourself
|
||||
splits.push((idx, (start, end)));
|
||||
current_offset = end;
|
||||
i += 1;
|
||||
}
|
||||
|
||||
// We also insert the splits that are inbetween the added tokens, to split the entire string
|
||||
let mut start_offset = 0;
|
||||
let mut splits = splits
|
||||
.into_iter()
|
||||
.flat_map(|(idx, (start, end))| {
|
||||
let mut splits = vec![];
|
||||
if start_offset < start {
|
||||
splits.push((None, (start_offset, start)));
|
||||
}
|
||||
splits.push((Some(idx), (start, end)));
|
||||
start_offset = end;
|
||||
|
||||
splits
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if let Some((_, (_, end))) = splits.iter().last().copied() {
|
||||
if end < sentence.get().len() {
|
||||
splits.push((None, (end, sentence.get().len())));
|
||||
}
|
||||
}
|
||||
|
||||
if splits.is_empty() {
|
||||
vec![(sentence, None)]
|
||||
} else {
|
||||
splits
|
||||
.into_iter()
|
||||
.map(|(idx, (start, end))| {
|
||||
// TODO: Check this works
|
||||
let normalized = sentence
|
||||
.slice_bytes(Range::Normalized(start..end))
|
||||
.expect("Error while extracting normalized Range");
|
||||
|
||||
// Find out the associated AddedToken, and its id
|
||||
let id = idx.map(|idx| split_re.1[idx]);
|
||||
|
||||
(normalized, id)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the additional vocabulary from the given sentence, normalizing it along the way.
|
||||
///
|
||||
/// Some tokens should match against their normalized representation, as well as the
|
||||
/// non-normalized one. For example, when we expect to extract the token `yesterday` in the
|
||||
/// input sentence `I read a book Yesterday`, if the normalizer is supposed to lowercase
|
||||
/// everything, we expect a match.
|
||||
///
|
||||
/// This method returns a `Vec` of `(NormalizedString, Option<u32>)`, where the optional `u32`
|
||||
/// contains the relevant ID if this is an additional token.
|
||||
pub fn extract_and_normalize(
|
||||
&self,
|
||||
normalizer: Option<&dyn Normalizer>,
|
||||
sentence: &str,
|
||||
) -> Vec<(NormalizedString, Option<u32>)> {
|
||||
// 1. We extract all the non-normalized tokens from the non-normalized string
|
||||
let pieces = self.extract(NormalizedString::from(sentence), &self.split_re);
|
||||
|
||||
// 2. Then extract the normalized tokens from the normalized pieces of the string
|
||||
pieces
|
||||
.into_iter()
|
||||
.flat_map(|(mut normalized, id)| {
|
||||
if id.is_some() {
|
||||
// If the piece has an associated ID, we already extracted something,
|
||||
// so we just return it
|
||||
vec![(normalized, id)]
|
||||
} else {
|
||||
// Otherwise, we need to normalized the string, and then proceed to extracting
|
||||
normalizer.map(|n| n.normalize(&mut normalized));
|
||||
self.extract(normalized, &self.split_normalized_re)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub(super) struct AddedTokenWithId {
|
||||
/// The id assigned to this token
|
||||
pub id: u32,
|
||||
/// Whether this is a special token
|
||||
pub special: bool,
|
||||
|
||||
#[serde(flatten)]
|
||||
/// The target AddedToken
|
||||
pub token: AddedToken,
|
||||
}
|
||||
|
||||
impl Serialize for AddedVocabulary {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut vocabulary = serializer.serialize_seq(Some(self.added_tokens_map.len()))?;
|
||||
|
||||
let mut added_tokens = self
|
||||
.added_tokens_map_r
|
||||
.iter()
|
||||
.map(|(id, token)| AddedTokenWithId {
|
||||
id: *id,
|
||||
special: self.special_tokens_set.contains(&token.content),
|
||||
token: token.clone(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
// We need to have these added tokens ordered by ascending ID
|
||||
added_tokens.sort_unstable_by_key(|o| o.id);
|
||||
|
||||
for token in added_tokens {
|
||||
vocabulary.serialize_element(&token)?;
|
||||
}
|
||||
|
||||
vocabulary.end()
|
||||
}
|
||||
}
|
||||
@@ -211,6 +211,7 @@ impl std::str::FromStr for Tokenizer {
|
||||
impl Tokenizer {
|
||||
/// Instantiate a new Tokenizer, with the given Model
|
||||
pub fn new(model: Box<dyn Model>) -> Self {
|
||||
let original_vocab_size = model.get_vocab_size();
|
||||
Tokenizer {
|
||||
normalizer: None,
|
||||
pre_tokenizer: None,
|
||||
@@ -218,7 +219,7 @@ impl Tokenizer {
|
||||
post_processor: None,
|
||||
decoder: None,
|
||||
|
||||
added_vocabulary: AddedVocabulary::new(),
|
||||
added_vocabulary: AddedVocabulary::new(original_vocab_size),
|
||||
|
||||
truncation: None,
|
||||
padding: None,
|
||||
@@ -302,6 +303,8 @@ impl Tokenizer {
|
||||
/// Set the model
|
||||
pub fn with_model(&mut self, model: Box<dyn Model>) -> &Self {
|
||||
self.model = model;
|
||||
self.added_vocabulary
|
||||
.update_original_vocab_size(self.model.get_vocab_size());
|
||||
self
|
||||
}
|
||||
|
||||
@@ -389,16 +392,14 @@ impl Tokenizer {
|
||||
pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
|
||||
let mut normalized = self
|
||||
.added_vocabulary
|
||||
.extract(sentence)
|
||||
.extract_and_normalize(self.normalizer.as_deref(), sentence)
|
||||
.into_iter()
|
||||
.map(|(sentence, id)| -> Result<NormalizedString> {
|
||||
.map(|(mut sentence, id)| -> Result<NormalizedString> {
|
||||
if id.is_some() {
|
||||
Ok(sentence)
|
||||
} else {
|
||||
let mut normalized = self.do_normalize(sentence)?;
|
||||
let _ = self.pre_tokenize(&mut normalized)?;
|
||||
|
||||
Ok(normalized)
|
||||
self.pre_tokenize(&mut sentence)?;
|
||||
Ok(sentence)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
@@ -421,35 +422,37 @@ impl Tokenizer {
|
||||
|
||||
let mut sequence_encodings = vec![];
|
||||
for subseq in sequence {
|
||||
let results = self.added_vocabulary.extract(&subseq).into_iter().map(
|
||||
|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
|
||||
if let Some(id) = id {
|
||||
Ok((
|
||||
Encoding::new(
|
||||
vec![id],
|
||||
vec![type_id],
|
||||
vec![sentence.get().to_owned()],
|
||||
vec![Some(0)],
|
||||
vec![(0, sentence.len())],
|
||||
vec![0],
|
||||
vec![1],
|
||||
vec![],
|
||||
),
|
||||
sentence,
|
||||
))
|
||||
} else {
|
||||
// 1. Normalization
|
||||
let mut normalized = self.do_normalize(sentence)?;
|
||||
// 2. Pre tokenization
|
||||
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
|
||||
// 3. Model
|
||||
let tokens = self.model.tokenize(pre_tokenized)?;
|
||||
let encoding = Encoding::from_tokens(tokens, type_id);
|
||||
let results = self
|
||||
.added_vocabulary
|
||||
.extract_and_normalize(self.normalizer.as_deref(), &subseq)
|
||||
.into_iter()
|
||||
.map(
|
||||
|(mut normalized, id)| -> Result<(Encoding, NormalizedString)> {
|
||||
if let Some(id) = id {
|
||||
Ok((
|
||||
Encoding::new(
|
||||
vec![id],
|
||||
vec![type_id],
|
||||
vec![normalized.get().to_owned()],
|
||||
vec![Some(0)],
|
||||
vec![(0, normalized.len())],
|
||||
vec![0],
|
||||
vec![1],
|
||||
vec![],
|
||||
),
|
||||
normalized,
|
||||
))
|
||||
} else {
|
||||
// 1. Pre tokenization
|
||||
let pre_tokenized = self.pre_tokenize(&mut normalized)?;
|
||||
// 2. Model
|
||||
let tokens = self.model.tokenize(pre_tokenized)?;
|
||||
let encoding = Encoding::from_tokens(tokens, type_id);
|
||||
|
||||
Ok((encoding, normalized))
|
||||
}
|
||||
},
|
||||
);
|
||||
Ok((encoding, normalized))
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let (all_encodings, all_normalized) =
|
||||
ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;
|
||||
@@ -675,6 +678,8 @@ impl Tokenizer {
|
||||
|
||||
let (model, special_tokens) = trainer.train(words)?;
|
||||
self.model = model;
|
||||
self.added_vocabulary
|
||||
.update_original_vocab_size(self.model.get_vocab_size());
|
||||
self.add_special_tokens(&special_tokens);
|
||||
|
||||
Ok(())
|
||||
@@ -752,13 +757,16 @@ impl Tokenizer {
|
||||
/// Register the given tokens as special tokens. This is especially useful for removing
|
||||
/// these special tokens while decoding
|
||||
pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
|
||||
self.added_vocabulary
|
||||
.add_special_tokens(tokens, self.model.as_ref())
|
||||
self.added_vocabulary.add_special_tokens(
|
||||
tokens,
|
||||
self.model.as_ref(),
|
||||
self.normalizer.as_deref(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Add the given tokens to the added vocabulary
|
||||
pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
|
||||
self.added_vocabulary
|
||||
.add_tokens(tokens, self.model.as_ref())
|
||||
.add_tokens(tokens, self.model.as_ref(), self.normalizer.as_deref())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ fn add_tokens() {
|
||||
|
||||
assert_eq!(
|
||||
tokenizer.add_special_tokens(&[
|
||||
AddedToken::from("<cls>".into()),
|
||||
AddedToken::from("<sep>".into())
|
||||
AddedToken::from("<cls>".into(), true),
|
||||
AddedToken::from("<sep>".into(), true)
|
||||
]),
|
||||
2
|
||||
);
|
||||
@@ -19,8 +19,8 @@ fn add_tokens() {
|
||||
|
||||
assert_eq!(
|
||||
tokenizer.add_tokens(&[
|
||||
AddedToken::from("hello".into()),
|
||||
AddedToken::from("world".into())
|
||||
AddedToken::from("hello".into(), false),
|
||||
AddedToken::from("world".into(), false)
|
||||
]),
|
||||
2
|
||||
);
|
||||
@@ -31,7 +31,7 @@ fn add_tokens() {
|
||||
#[test]
|
||||
fn lstrip_tokens() {
|
||||
let mut tokenizer = get_byte_level(true, false);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).lstrip(true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).lstrip(true)]);
|
||||
|
||||
let input = "I saw a <mask> 😺";
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
@@ -49,7 +49,7 @@ fn lstrip_tokens() {
|
||||
#[test]
|
||||
fn rstrip_tokens() {
|
||||
let mut tokenizer = get_byte_level(false, false);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).rstrip(true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
|
||||
|
||||
let input = "I saw a <mask> 😺";
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
@@ -62,7 +62,7 @@ fn rstrip_tokens() {
|
||||
// When `add_prefix_space = true` rstrip cannot work as a prefix space is added
|
||||
// to the next token
|
||||
let mut tokenizer = get_byte_level(true, false);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).rstrip(true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
|
||||
|
||||
let input = "I saw a <mask> 😺";
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
@@ -77,7 +77,7 @@ fn rstrip_tokens() {
|
||||
fn single_word_tokens() {
|
||||
// If `single_word = true` it shouldn't split `dancing`
|
||||
let mut tokenizer = get_byte_level(false, false);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into()).single_word(true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(true)]);
|
||||
|
||||
let input = "I like dancing";
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
@@ -86,7 +86,7 @@ fn single_word_tokens() {
|
||||
|
||||
// If `single_word = false` it should split `dancing`
|
||||
let mut tokenizer = get_byte_level(false, false);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into()).single_word(false)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(false)]);
|
||||
|
||||
let input = "I like dancing";
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
@@ -98,9 +98,9 @@ fn single_word_tokens() {
|
||||
fn overlapping_tokens() {
|
||||
let mut tokenizer = get_byte_level(false, false);
|
||||
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("danc".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("nci".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
|
||||
|
||||
let input = "I like dancing";
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
@@ -109,10 +109,10 @@ fn overlapping_tokens() {
|
||||
|
||||
let mut tokenizer = get_byte_level(false, false);
|
||||
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("nci".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("danc".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ike".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("ike".into(), true)]);
|
||||
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ fn split_on_added_tokens_bert() {
|
||||
let input = "Yesterday I saw a [MASK] far away";
|
||||
|
||||
let mut tokenizer = get_bert();
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into())]);
|
||||
tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into(), true)]);
|
||||
let output = tokenizer.encode(input, false).unwrap();
|
||||
|
||||
assert_eq!(
|
||||
|
||||
Reference in New Issue
Block a user