Encode special tokens (#1437)

* add doc in the code

* add option to skip special tokens

* nits

* add api dummy for now

* Fmt.

* Fix fmt.

* Fix the stub.

* add a test

* add a test in python

* style it

* nits

* add getter and setters

* stub

* update python test

* fmt

* last nit

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Arthur
2024-01-19 12:43:43 +01:00
committed by GitHub
parent 888dd4bc65
commit 6a77d4859b
5 changed files with 173 additions and 0 deletions

View File

@ -836,6 +836,18 @@ class Tokenizer:
Returns:
A :obj:`List` of :class:`~tokenizers.Encoding`: The encoded batch
"""
pass
@property
def encode_special_tokens(self):
"""
Modifies the tokenizer in order to use or not the special tokens
during encoding.
Args:
value (:obj:`bool`):
Whether to use the special tokens or not
"""
pass
@staticmethod

View File

@ -1109,6 +1109,25 @@ impl PyTokenizer {
self.tokenizer.id_to_token(id)
}
/// Modifies the tokenizer in order to use or not the special tokens
/// during encoding.
///
/// Args:
/// value (:obj:`bool`):
/// Whether to use the special tokens or not
///
#[setter]
fn set_encode_special_tokens(&mut self, value: bool) {
self.tokenizer.set_encode_special_tokens(value);
}
/// Get the value of the `encode_special_tokens` attribute
///
/// Returns:
/// :obj:`bool`: the tokenizer's encode_special_tokens attribute
#[getter]
fn get_encode_special_tokens(&self) -> bool {
self.tokenizer.get_encode_special_tokens()
}
/// Add the given tokens to the vocabulary
///
/// The given tokens are added only if they don't already exist in the vocabulary.

View File

@ -457,3 +457,34 @@ class TestTokenizer:
output = tokenizer.encode("A sentence 🤗")
assert output.ids == [1, 10, 2, 3, 4, 5, 10, 6, 7, 8, 9]
assert output.tokens == ["A", " ", "sen", "te", "n", "ce", " ", "<0xF0>", "<0x9F>", "<0xA4>", "<0x97>"]
def test_encode_special_tokens(self):
tokenizer = Tokenizer.from_pretrained("t5-base")
tokenizer.add_tokens(["<eot>"])
tokenizer.add_special_tokens(["<end_of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<end_of_text>", "▁dear", "<eot>", "▁friend", "!"]
tokenizer.encode_special_tokens = True
assert tokenizer.encode_special_tokens == True
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == [
"▁Hey",
"▁there",
"<",
"end",
"_",
"of",
"_",
"text",
">",
"▁dear",
"<eot>",
"▁friend",
"!",
]
tokenizer.add_tokens(["of_text>"])
output = tokenizer.encode("Hey there<end_of_text> dear<eot>friend!", add_special_tokens=False)
assert output.tokens == ["▁Hey", "▁there", "<", "end", "_", "of_text>", "▁dear", "<eot>", "▁friend", "!"]

View File

@ -160,6 +160,9 @@ pub(super) struct AddedVocabulary {
split_trie: MatchingSet,
/// A RegexSet containing all the normalized patterns used to split on AddedTokens
split_normalized_trie: MatchingSet,
/// Whether or not special tokens should be splitted when encoding. This is equivalent to ignoring them
encode_special_tokens: bool,
}
impl AddedVocabulary {
@ -180,6 +183,7 @@ impl AddedVocabulary {
special_tokens_set: HashSet::new(),
split_trie: (trie, vec![]),
split_normalized_trie: (normalized_trie, vec![]),
encode_special_tokens: false,
}
}
/// Size of the additional vocabulary
@ -214,6 +218,15 @@ impl AddedVocabulary {
.or_else(|| model.id_to_token(id))
}
//
pub fn set_encode_special_tokens(&mut self, value: bool) {
self.encode_special_tokens = value;
}
pub fn get_encode_special_tokens(&self) -> bool {
self.encode_special_tokens
}
/// Check if a token is a special token
pub fn is_special_token(&self, token: &str) -> bool {
self.special_tokens_set.contains(token)
@ -356,6 +369,12 @@ impl AddedVocabulary {
let aho_id = mat.pattern();
let id = split_re.1[aho_id];
let added_token = &self.added_tokens_map_r.get(&id).unwrap();
if self.encode_special_tokens && self.special_tokens_set.contains(&added_token.content)
{
continue;
}
if added_token.single_word {
let start_space = start == 0 || !ends_with_word(&sentence[..start]);
let stop_space = stop == sentence.len() || !starts_with_word(&sentence[stop..]);
@ -436,6 +455,18 @@ impl AddedVocabulary {
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_trie)))
.expect("AddedVocabulary bad split");
// <s> normalized = False
// "I read a book <s>Hey" -> "I read a book", " <s>", "Hey"
// </s> normalized = True -> "▁</s>"
// "I read a book</s>Hey" -> "I read a book</s>Hey"
// Day normalized = True -> "Day"
// "I read a book monday" -> "I read a book monday"
// [DAY] normalized = False -> "Day"
// "I read a [DAY] monday" -> "I read a " "[DAY]", "book monday"
// 320055
// 2. Then extract the normalized tokens from the normalized pieces of the string
pretokenized
.split(|_, mut sequence| {
@ -444,6 +475,14 @@ impl AddedVocabulary {
})
.expect("AddedVocabulary bad split");
// ["I read a book", " <s>", "Hey"] -> ["▁I read a book", "▁ <s>", "▁Hey"]
// ["▁I read a book", "▁ <s>", "▁Hey"] -> [.., "▁ ", "<s>", "▁Hey"]
// </s> normalized = True -> "▁</s>"
// "I read a book</s>Hey" -> ["▁I read a book", "<","/","s",">", "Hey"]
// "I read a " "[DAY]", "book monday" -> "i read a " "[day]", "book monday"
pretokenized
}
}
@ -880,4 +919,66 @@ mod tests {
]
);
}
#[test]
fn test_encode_special_tokens() {
let model = ModelMock::new(&[]);
let mut vocab = AddedVocabulary::new();
let normalizer = Lowercase;
vocab.add_tokens(
&[
AddedToken::from("<mask>", true)
.lstrip(true)
.rstrip(true)
.single_word(true),
AddedToken::from("ask>", false),
AddedToken::from("<pad>", true),
],
&model,
Some(&normalizer),
);
vocab.set_encode_special_tokens(true);
let result = vocab.extract_and_normalize(
Some(&normalizer),
"Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
);
assert_eq!(
simplify_output(&result),
vec![
("hi <m", None),
("ask>", Some(vec![1])),
(" there\t<m", None),
("ask>", Some(vec![1])),
("\t<m", None),
("ask>", Some(vec![1])),
("\u{2000} <pad> <m", None),
("ask>", Some(vec![1])),
("<pad><pad>", None)
]
);
vocab.set_encode_special_tokens(false);
let result = vocab.extract_and_normalize(
Some(&normalizer),
"Hi <mask> there\t<mask>\t<mask>\u{2000} <pad> <mask><pad><pad>",
);
assert_eq!(
simplify_output(&result),
vec![
("hi", None),
(" <mask> ", Some(vec![0])),
("there", None),
("\t<mask>\t", Some(vec![0])),
("<mask>\u{2000} ", Some(vec![0])),
("<pad>", Some(vec![2])),
(" <mask>", Some(vec![0])),
("<pad>", Some(vec![2])),
("<pad>", Some(vec![2]))
]
);
}
}

View File

@ -685,6 +685,16 @@ where
self.added_vocabulary.id_to_token(id, &self.model)
}
/// set the added bocab's splitting scheme
pub fn set_encode_special_tokens(&mut self, value: bool) {
self.added_vocabulary.set_encode_special_tokens(value);
}
/// Get added token value
pub fn get_encode_special_tokens(&self) -> bool {
self.added_vocabulary.get_encode_special_tokens()
}
/// Encode a single sequence
fn encode_single_sequence(
&self,