Rust - Add AddedVocabulary + normalized option on AddedToken

This commit is contained in:
Anthony MOI
2020-06-15 22:46:30 -04:00
parent 7dff86b704
commit 397cc539da
7 changed files with 562 additions and 74 deletions

View File

@@ -28,8 +28,8 @@ pub struct AddedToken {
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content.to_owned());
+    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = tk::tokenizer::AddedToken::from(content.to_owned(), is_special_token);
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
@@ -38,6 +38,7 @@ impl AddedToken {
                     "single_word" => token = token.single_word(value.extract()?),
                     "lstrip" => token = token.lstrip(value.extract()?),
                     "rstrip" => token = token.rstrip(value.extract()?),
+                    "normalized" => token = token.normalized(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
@@ -65,6 +66,11 @@ impl AddedToken {
     fn get_single_word(&self) -> bool {
         self.token.single_word
     }
+
+    #[getter]
+    fn get_normalized(&self) -> bool {
+        self.token.normalized
+    }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
@@ -533,7 +539,7 @@ impl Tokenizer {
         self.tokenizer.token_to_id(token)
     }
-    fn id_to_token(&self, id: u32) -> Option<String> {
+    fn id_to_token(&self, id: u32) -> Option<&str> {
         self.tokenizer.id_to_token(id)
     }
@@ -542,10 +548,7 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        ..Default::default()
-                    })
+                    Ok(tk::tokenizer::AddedToken::from(content, false))
                 } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
                     Ok(token.token.clone())
                 } else {
@@ -564,10 +567,7 @@ impl Tokenizer {
            .into_iter()
            .map(|token| {
                if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken {
-                        content,
-                        ..Default::default()
-                    })
+                    Ok(tk::tokenizer::AddedToken::from(content, true))
                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
                    Ok(token.token.clone())
                } else {
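
The Rust constructor behind this binding now takes a flag marking the token as special, and special tokens default to `normalized = false` while regular added tokens default to `true`. A minimal sketch of these semantics (not part of the diff), assuming the core crate is used directly under the `tk` alias from the bindings above:

    use tk::tokenizer::AddedToken;

    let special = AddedToken::from("[MASK]".to_owned(), true);
    assert!(!special.normalized); // special tokens are matched verbatim by default

    let regular = AddedToken::from("ing".to_owned(), false);
    assert!(regular.normalized); // regular added tokens match against the normalized text

    // Either default can be overridden, which is what the new `normalized` kwarg exposes:
    let verbatim = AddedToken::from("ing".to_owned(), false).normalized(false);
    assert!(!verbatim.normalized);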

View File

@@ -17,10 +17,9 @@ fn create_gpt2_tokenizer(bpe: BPE) -> Tokenizer {
     let mut tokenizer = Tokenizer::new(Box::new(bpe));
     tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
-    tokenizer.add_tokens(&[
-        AddedToken::from(String::from("ing")).single_word(false),
-        AddedToken::from(String::from("[ENT]")).single_word(true),
-    ]);
+    tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
+    tokenizer
+        .add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
     tokenizer
 }
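
The "[ENT]" token moves from `add_tokens` to `add_special_tokens` because special tokens are tracked separately (see `special_tokens_set` in the new AddedVocabulary below), which lets them be stripped while decoding. A hedged sketch of the difference, reusing the tokenizer built above:

    // Encode, then decode with skip_special_tokens = true: "[ENT]" is dropped from the
    // decoded text, while the regular added token "ing" is kept.
    let ids = tokenizer.encode("I am encoding [ENT] something", false).unwrap().get_ids().to_vec();
    let decoded = tokenizer.decode(ids, true).unwrap();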

View File

@@ -21,10 +21,9 @@ fn shell(matches: &ArgMatches) -> Result<()> {
     tokenizer.with_pre_tokenizer(Box::new(ByteLevel::default()));
     tokenizer.with_decoder(Box::new(ByteLevel::default()));
-    tokenizer.add_tokens(&[
-        AddedToken::from(String::from("ing")).single_word(false),
-        AddedToken::from(String::from("[ENT]")).single_word(true),
-    ]);
+    tokenizer.add_tokens(&[AddedToken::from(String::from("ing"), false).single_word(false)]);
+    tokenizer
+        .add_special_tokens(&[AddedToken::from(String::from("[ENT]"), true).single_word(true)]);
     let stdin = io::stdin();
     let mut handle = stdin.lock();

View File

@@ -0,0 +1,482 @@
use super::{Model, NormalizedString, Normalizer, Range};
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use std::collections::{HashMap, HashSet};
/// Represents a token added by the user on top of the existing Model vocabulary.
/// An AddedToken can be configured to specify the behavior it should have in various situations,
/// such as:
/// - Whether it should only match against single, whole words
/// - Whether to include any whitespace on its left or right
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddedToken {
/// The content of the added token
pub content: String,
/// Whether this token must be a single word or can break words
pub single_word: bool,
/// Whether this token should strip whitespaces on its left
pub lstrip: bool,
/// Whether this token should strip whitespaces on its right
pub rstrip: bool,
/// Whether this token should be normalized
pub normalized: bool,
}
impl AddedToken {
/// Build this token from the given content, specifying if it is intended to be a
/// special token. Special tokens are not normalized by default.
pub fn from(content: String, special_token: bool) -> Self {
AddedToken {
content,
normalized: !special_token,
..Default::default()
}
}
/// Specify whether this token should only match on whole single words, and never
/// part of a word.
pub fn single_word(mut self, single_word: bool) -> Self {
self.single_word = single_word;
self
}
/// Specify whether this token should include all the whitespaces on its left, in
/// order to strip them out.
pub fn lstrip(mut self, lstrip: bool) -> Self {
self.lstrip = lstrip;
self
}
/// Specify whether this token should include all the whitespaces on its right, in
/// order to strip them out.
pub fn rstrip(mut self, rstrip: bool) -> Self {
self.rstrip = rstrip;
self
}
/// Specify whether this token should be normalized, and/or match against its normalized
/// version in the input text.
pub fn normalized(mut self, normalized: bool) -> Self {
self.normalized = normalized;
self
}
/// Retrieve the pattern built for this token, according to all the specified parameters.
pub fn get_pattern(&self, normalizer: Option<&dyn Normalizer>) -> String {
let mut r = if self.single_word {
let first_b = self
.content
.chars()
.next()
.map(|c| {
if regex_syntax::is_word_character(c) {
r"\b"
} else {
""
}
})
.unwrap();
let last_b = self
.content
.chars()
.last()
.map(|c| {
if regex_syntax::is_word_character(c) {
r"\b"
} else {
""
}
})
.unwrap();
// Normalize the content
let mut content = NormalizedString::from(&self.content);
normalizer.map(|n| n.normalize(&mut content));
format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
} else {
regex::escape(&self.content)
};
if self.lstrip && self.rstrip {
r = format!(r"(\s)?{}(\s)?", r);
} else if self.lstrip {
r = format!(r"(\s)?{}", r);
} else if self.rstrip {
r = format!(r"{}(\s)?", r);
}
r
}
}
impl Default for AddedToken {
fn default() -> Self {
AddedToken {
content: String::new(),
single_word: false,
lstrip: false,
rstrip: false,
normalized: false,
}
}
}
// We only want to hash on the content. AddedToken cannot be added multiple times with different
// options
impl std::hash::Hash for AddedToken {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.content.hash(state);
}
}
impl std::cmp::PartialEq for AddedToken {
fn eq(&self, other: &Self) -> bool {
self.content == other.content
}
}
impl std::cmp::Eq for AddedToken {}
type MatchingSet = (regex::RegexSet, Vec<u32>);
///
/// A vocabulary built on top of the Model
///
/// This provides a way to add new vocabulary to a Tokenizer that has already been trained,
/// in a previous process, maybe by someone else. This is especially useful for fine-tuning,
/// where we want to adapt a model while adding new functionality through new special tokens,
/// or add extra tokens to cover words the model does not know, etc.
///
/// One of the reasons we need to handle these tokens outside of the model is simply that,
/// for many models, it is not possible to add new tokens after the training process. For example,
/// using BPE, the training process generates merge pairs along with the vocabulary, and any token
/// in the vocabulary can be decomposed into other tokens, down to the original alphabet. If we
/// were to add new tokens after this training process, we couldn't make sure the merge pairs
/// exist as required.
///
pub(super) struct AddedVocabulary {
/// The size of the original vocabulary. This is what we use to determine the new
/// ids we need to generate
original_vocab_size: usize,
/// Contains the mapping from String to ID as the user intended it. This map
/// contains both special tokens and classic added tokens.
added_tokens_map: HashMap<String, u32>,
/// Contains the mapping from ID to AddedToken for all the added tokens, both special
/// and classic.
added_tokens_map_r: HashMap<u32, AddedToken>,
/// Contains only the classic AddedToken, in the specific order the user gave them.
added_tokens: Vec<AddedToken>,
/// Contains only the special AddedToken, in the specific order the user gave them.
special_tokens: Vec<AddedToken>,
/// A set containing all the special tokens, for easy access while decoding. This lets us
/// remove them easily, with O(1) complexity.
special_tokens_set: HashSet<String>,
/// A RegexSet containing all the non-normalized patterns used to split on AddedTokens
split_re: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_re: MatchingSet,
}
impl AddedVocabulary {
pub fn new(original_vocab_size: usize) -> Self {
Self {
original_vocab_size,
added_tokens_map: HashMap::new(),
added_tokens_map_r: HashMap::new(),
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
}
}
/// Sets the original vocabulary size. We need this value to return IDs that
/// are shifted according to the original vocabulary.
pub fn update_original_vocab_size(&mut self, size: usize) {
self.original_vocab_size = size;
}
/// Size of the additional vocabulary
pub fn len(&self) -> usize {
self.added_tokens_map.len()
}
/// Get the additional vocabulary
pub fn get_vocab(&self) -> &HashMap<String, u32> {
&self.added_tokens_map
}
/// Get the id matching one of our tokens, if it exists
pub fn token_to_id(&self, token: &str) -> Option<&u32> {
self.added_tokens_map.get(token)
}
/// Get the token matching the given id if it exists
pub fn id_to_token(&self, id: u32) -> Option<&str> {
self.added_tokens_map_r.get(&id).map(|t| t.content.as_ref())
}
/// Check if a token is a special token
pub fn is_special_token(&self, token: &str) -> bool {
self.special_tokens_set.contains(token)
}
/// Add some special tokens to the vocabulary
pub fn add_special_tokens(
&mut self,
tokens: &[AddedToken],
model: &dyn Model,
normalizer: Option<&dyn Normalizer>,
) -> usize {
for token in tokens {
if !self.special_tokens_set.contains(&token.content) {
self.special_tokens.push(token.to_owned());
self.special_tokens_set.insert(token.content.clone());
}
}
let added = self.add_tokens(&tokens, model, normalizer);
self.refresh_added_tokens(normalizer);
added
}
/// Add some tokens to the vocabulary
pub fn add_tokens(
&mut self,
tokens: &[AddedToken],
model: &dyn Model,
normalizer: Option<&dyn Normalizer>,
) -> usize {
let mut ignored = 0;
for token in tokens {
if token.content.is_empty() {
ignored += 1;
continue;
}
let id = if let Some(id) = model.token_to_id(&token.content) {
ignored += 1;
id
} else {
let new_id = (self.original_vocab_size + self.added_tokens_map.len()) as u32;
self.added_tokens_map.insert(token.content.clone(), new_id);
if !self.special_tokens_set.contains(&token.content) {
self.added_tokens.push(token.clone());
}
new_id
};
// Update the reverse mapping (id -> AddedToken)
self.added_tokens_map_r
.entry(id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());
}
self.refresh_added_tokens(normalizer);
// Return the number of added tokens
tokens.len() - ignored
}
/// Reconstruct our internal RegexSet when new tokens are added to the vocabulary.
///
/// We keep two different RegexSet, one that will take care of matching against the
/// non-normalized string, and one matching against the normalized one.
fn refresh_added_tokens(&mut self, normalizer: Option<&dyn Normalizer>) {
type TupleTokenId<'a> = (&'a AddedToken, u32);
let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
.special_tokens
.iter()
.chain(self.added_tokens.iter())
// TODO: Fix this: special tokens that are part of the original vocabulary are
// not part of the `self.added_tokens_map` and so it crashes.
.map(|token| (token, self.added_tokens_map[&token.content]))
.partition(|(token, _)| token.normalized);
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
self.split_re = (
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
ids,
);
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
self.split_normalized_re = (
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
ids,
);
}
/// Split the given sentence against the patterns in the given MatchingSet, returning the
/// resulting pieces along with the id of the matching AddedToken when a piece corresponds to one.
fn extract(
&self,
sentence: NormalizedString,
split_re: &MatchingSet,
) -> Vec<(NormalizedString, Option<u32>)> {
let mut matches = split_re
.0
.matches(sentence.get())
.into_iter()
.flat_map(|idx| {
regex::Regex::new(&split_re.0.patterns()[idx])
.unwrap()
.find_iter(sentence.get())
.map(|m| (idx, (m.start(), m.end())))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
// We sort all the matches by their start and then by their pattern id
matches.sort_by(
|(idxa, (sa, _)), (idxb, (sb, _))| {
if sa != sb {
sa.cmp(sb)
} else {
idxa.cmp(idxb)
}
},
);
// Select which matches we want to keep (some of them may be overlapping)
let mut i = 0;
let mut current_offset = 0;
let mut splits = Vec::with_capacity(matches.len());
while i < matches.len() {
let (idx, (start, end)) = matches[i];
// The current match starts before the current offset, let's skip it
if start < current_offset {
i += 1;
continue;
}
// Find out if we have overlapping neighbors. If so, we keep the one with the lowest
// idx, and apply it, then continue. All others will be skipped since `current_offset`
// will have been increased
if i + 1 < matches.len() {
if let Some((idx, (s, e))) = matches[i..]
.iter()
.take_while(|(_, (s, e))| *s < end && start < *e)
.min() // Order on idx first
.copied()
{
splits.push((idx, (s, e)));
current_offset = e;
i += 1;
continue;
}
}
// We didn't find overlapping neighbors, apply ourself
splits.push((idx, (start, end)));
current_offset = end;
i += 1;
}
// We also insert the splits that are in between the added tokens, to split the entire string
let mut start_offset = 0;
let mut splits = splits
.into_iter()
.flat_map(|(idx, (start, end))| {
let mut splits = vec![];
if start_offset < start {
splits.push((None, (start_offset, start)));
}
splits.push((Some(idx), (start, end)));
start_offset = end;
splits
})
.collect::<Vec<_>>();
if let Some((_, (_, end))) = splits.iter().last().copied() {
if end < sentence.get().len() {
splits.push((None, (end, sentence.get().len())));
}
}
if splits.is_empty() {
vec![(sentence, None)]
} else {
splits
.into_iter()
.map(|(idx, (start, end))| {
// TODO: Check this works
let normalized = sentence
.slice_bytes(Range::Normalized(start..end))
.expect("Error while extracting normalized Range");
// Find out the associated AddedToken, and its id
let id = idx.map(|idx| split_re.1[idx]);
(normalized, id)
})
.collect()
}
}
/// Extract the additional vocabulary from the given sentence, normalizing it along the way.
///
/// Some tokens should match against their normalized representation, as well as the
/// non-normalized one. For example, if the input sentence is `I read a book Yesterday` and
/// the normalizer lowercases everything, we still expect the added token `yesterday` to match.
///
/// This method returns a `Vec` of `(NormalizedString, Option<u32>)`, where the optional `u32`
/// contains the relevant ID if this is an additional token.
pub fn extract_and_normalize(
&self,
normalizer: Option<&dyn Normalizer>,
sentence: &str,
) -> Vec<(NormalizedString, Option<u32>)> {
// 1. We extract all the non-normalized tokens from the non-normalized string
let pieces = self.extract(NormalizedString::from(sentence), &self.split_re);
// 2. Then extract the normalized tokens from the normalized pieces of the string
pieces
.into_iter()
.flat_map(|(mut normalized, id)| {
if id.is_some() {
// If the piece has an associated ID, we already extracted something,
// so we just return it
vec![(normalized, id)]
} else {
// Otherwise, we need to normalize the string and then proceed with the extraction
normalizer.map(|n| n.normalize(&mut normalized));
self.extract(normalized, &self.split_normalized_re)
}
})
.collect::<Vec<_>>()
}
}
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
/// The id assigned to this token
pub id: u32,
/// Whether this is a special token
pub special: bool,
#[serde(flatten)]
/// The target AddedToken
pub token: AddedToken,
}
impl Serialize for AddedVocabulary {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut vocabulary = serializer.serialize_seq(Some(self.added_tokens_map.len()))?;
let mut added_tokens = self
.added_tokens_map_r
.iter()
.map(|(id, token)| AddedTokenWithId {
id: *id,
special: self.special_tokens_set.contains(&token.content),
token: token.clone(),
})
.collect::<Vec<_>>();
// We need to have these added tokens ordered by ascending ID
added_tokens.sort_unstable_by_key(|o| o.id);
for token in added_tokens {
vocabulary.serialize_element(&token)?;
}
vocabulary.end()
}
}
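
To make the splitting logic above concrete, here is an illustrative sketch (not part of the commit) of the patterns `AddedToken::get_pattern` builds when no normalizer is configured; the expected strings in the comments follow from the escaping, word-boundary, and optional-whitespace rules implemented above:

    let special = AddedToken::from("[MASK]".into(), true);
    println!("{}", special.get_pattern(None)); // only escaped: \[MASK\]

    let single = AddedToken::from("ing".into(), false).single_word(true);
    println!("{}", single.get_pattern(None)); // word boundaries on both sides: \bing\b

    let stripped = AddedToken::from("<mask>".into(), true).lstrip(true).rstrip(true);
    println!("{}", stripped.get_pattern(None)); // optional whitespace captured: (\s)?<mask>(\s)?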

View File

@@ -211,6 +211,7 @@ impl std::str::FromStr for Tokenizer {
 impl Tokenizer {
     /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model>) -> Self {
+        let original_vocab_size = model.get_vocab_size();
         Tokenizer {
             normalizer: None,
             pre_tokenizer: None,
@@ -218,7 +219,7 @@ impl Tokenizer {
             post_processor: None,
             decoder: None,
-            added_vocabulary: AddedVocabulary::new(),
+            added_vocabulary: AddedVocabulary::new(original_vocab_size),
             truncation: None,
             padding: None,
@@ -302,6 +303,8 @@ impl Tokenizer {
     /// Set the model
     pub fn with_model(&mut self, model: Box<dyn Model>) -> &Self {
         self.model = model;
+        self.added_vocabulary
+            .update_original_vocab_size(self.model.get_vocab_size());
         self
     }
@@ -389,16 +392,14 @@ impl Tokenizer {
     pub fn normalize(&self, sentence: &str) -> Result<NormalizedString> {
         let mut normalized = self
             .added_vocabulary
-            .extract(sentence)
+            .extract_and_normalize(self.normalizer.as_deref(), sentence)
             .into_iter()
-            .map(|(sentence, id)| -> Result<NormalizedString> {
+            .map(|(mut sentence, id)| -> Result<NormalizedString> {
                 if id.is_some() {
                     Ok(sentence)
                 } else {
-                    let mut normalized = self.do_normalize(sentence)?;
-                    let _ = self.pre_tokenize(&mut normalized)?;
-                    Ok(normalized)
+                    self.pre_tokenize(&mut sentence)?;
+                    Ok(sentence)
                 }
             })
             .collect::<Result<Vec<_>>>()?;
@@ -421,28 +422,30 @@ impl Tokenizer {
         let mut sequence_encodings = vec![];
         for subseq in sequence {
-            let results = self.added_vocabulary.extract(&subseq).into_iter().map(
-                |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
+            let results = self
+                .added_vocabulary
+                .extract_and_normalize(self.normalizer.as_deref(), &subseq)
+                .into_iter()
+                .map(
+                    |(mut normalized, id)| -> Result<(Encoding, NormalizedString)> {
                         if let Some(id) = id {
                             Ok((
                                 Encoding::new(
                                     vec![id],
                                     vec![type_id],
-                                    vec![sentence.get().to_owned()],
+                                    vec![normalized.get().to_owned()],
                                     vec![Some(0)],
-                                    vec![(0, sentence.len())],
+                                    vec![(0, normalized.len())],
                                     vec![0],
                                     vec![1],
                                     vec![],
                                 ),
-                                sentence,
+                                normalized,
                             ))
                         } else {
-                            // 1. Normalization
-                            let mut normalized = self.do_normalize(sentence)?;
-                            // 2. Pre tokenization
+                            // 1. Pre tokenization
                             let pre_tokenized = self.pre_tokenize(&mut normalized)?;
-                            // 3. Model
+                            // 2. Model
                             let tokens = self.model.tokenize(pre_tokenized)?;
                             let encoding = Encoding::from_tokens(tokens, type_id);
@@ -675,6 +678,8 @@ impl Tokenizer {
         let (model, special_tokens) = trainer.train(words)?;
         self.model = model;
+        self.added_vocabulary
+            .update_original_vocab_size(self.model.get_vocab_size());
         self.add_special_tokens(&special_tokens);
         Ok(())
@@ -752,13 +757,16 @@ impl Tokenizer {
     /// Register the given tokens as special tokens. This is especially useful for removing
     /// these special tokens while decoding
     pub fn add_special_tokens(&mut self, tokens: &[AddedToken]) -> usize {
-        self.added_vocabulary
-            .add_special_tokens(tokens, self.model.as_ref())
+        self.added_vocabulary.add_special_tokens(
+            tokens,
+            self.model.as_ref(),
+            self.normalizer.as_deref(),
+        )
     }
     /// Add the given tokens to the added vocabulary
     pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
         self.added_vocabulary
-            .add_tokens(tokens, self.model.as_ref())
+            .add_tokens(tokens, self.model.as_ref(), self.normalizer.as_deref())
     }
 }
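
With `extract_and_normalize` in the pipeline, added tokens are first matched verbatim against the raw input, and only the remaining pieces are normalized and matched against the normalized patterns. A rough sketch of what this enables, assuming a tokenizer configured with some lowercasing Normalizer (the `Lowercase` name here is illustrative):

    tokenizer.with_normalizer(Box::new(Lowercase));
    tokenizer.add_tokens(&[AddedToken::from("yesterday".into(), false)]); // normalized by default
    tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into(), true)]); // matched verbatim

    // "Yesterday" is found through the lowercased text, while "[MASK]" is extracted
    // from the raw input before any normalization runs.
    let output = tokenizer.encode("Yesterday I saw a [MASK] far away", false).unwrap();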

View File

@@ -9,8 +9,8 @@ fn add_tokens() {
     assert_eq!(
         tokenizer.add_special_tokens(&[
-            AddedToken::from("<cls>".into()),
-            AddedToken::from("<sep>".into())
+            AddedToken::from("<cls>".into(), true),
+            AddedToken::from("<sep>".into(), true)
         ]),
         2
     );
@@ -19,8 +19,8 @@ fn add_tokens() {
     assert_eq!(
         tokenizer.add_tokens(&[
-            AddedToken::from("hello".into()),
-            AddedToken::from("world".into())
+            AddedToken::from("hello".into(), false),
+            AddedToken::from("world".into(), false)
         ]),
         2
     );
@@ -31,7 +31,7 @@ fn add_tokens() {
 #[test]
 fn lstrip_tokens() {
     let mut tokenizer = get_byte_level(true, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).lstrip(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).lstrip(true)]);
     let input = "I saw a <mask> 😺";
     let output = tokenizer.encode(input, false).unwrap();
@@ -49,7 +49,7 @@ fn lstrip_tokens() {
 #[test]
 fn rstrip_tokens() {
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).rstrip(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
     let input = "I saw a <mask> 😺";
     let output = tokenizer.encode(input, false).unwrap();
@@ -62,7 +62,7 @@ fn rstrip_tokens() {
     // When `add_prefix_space = true` rstrip cannot work as a prefix space is added
     // to the next token
     let mut tokenizer = get_byte_level(true, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into()).rstrip(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("<mask>".into(), true).rstrip(true)]);
     let input = "I saw a <mask> 😺";
     let output = tokenizer.encode(input, false).unwrap();
@@ -77,7 +77,7 @@ fn rstrip_tokens() {
 fn single_word_tokens() {
     // If `single_word = true` it shouldn't split `dancing`
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into()).single_word(true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(true)]);
     let input = "I like dancing";
     let output = tokenizer.encode(input, false).unwrap();
@@ -86,7 +86,7 @@ fn single_word_tokens() {
     // If `single_word = false` it should split `dancing`
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into()).single_word(false)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true).single_word(false)]);
     let input = "I like dancing";
     let output = tokenizer.encode(input, false).unwrap();
@@ -98,9 +98,9 @@ fn single_word_tokens() {
 fn overlapping_tokens() {
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("danc".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("nci".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into())]);
+    tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
     let input = "I like dancing";
     let output = tokenizer.encode(input, false).unwrap();
@@ -109,10 +109,10 @@ fn overlapping_tokens() {
     let mut tokenizer = get_byte_level(false, false);
-    tokenizer.add_special_tokens(&[AddedToken::from("nci".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("danc".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("ing".into())]);
-    tokenizer.add_special_tokens(&[AddedToken::from("ike".into())]);
+    tokenizer.add_special_tokens(&[AddedToken::from("nci".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("danc".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ing".into(), true)]);
+    tokenizer.add_special_tokens(&[AddedToken::from("ike".into(), true)]);
     let output = tokenizer.encode(input, false).unwrap();
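
For the first `overlapping_tokens` case above, the overlap resolution in `AddedVocabulary::extract` plays out roughly as follows: on "I like dancing" the patterns match "danc" at bytes 7..11, "nci" at 9..12 and "ing" at 11..14. "danc" and "nci" overlap, the match with the lowest pattern index wins, and advancing `current_offset` to 11 skips "nci", so the extracted pieces end up as "I like ", "danc" and "ing".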

View File

@@ -158,7 +158,7 @@ fn split_on_added_tokens_bert() {
let input = "Yesterday I saw a [MASK] far away"; let input = "Yesterday I saw a [MASK] far away";
let mut tokenizer = get_bert(); let mut tokenizer = get_bert();
tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into())]); tokenizer.add_special_tokens(&[AddedToken::from("[MASK]".into(), true)]);
let output = tokenizer.encode(input, false).unwrap(); let output = tokenizer.encode(input, false).unwrap();
assert_eq!( assert_eq!(