Merge pull request #1335 from ArthurZucker/update-added-tokens

Update added tokens
Arthur
2023-09-07 12:48:54 +02:00
committed by GitHub
8 changed files with 215 additions and 67 deletions

View File

@ -30,10 +30,12 @@ class AddedToken:
text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
lowercasing the text, the token could be extracted from the input ``"I saw a lion
Yesterday"``.
special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
Defines whether this token should be skipped when decoding.
"""
def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True):
def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False):
pass
@property
def content(self):
@ -65,6 +67,12 @@ class AddedToken:
Get the value of the :obj:`single_word` option
"""
pass
@property
def special(self):
"""
Get the value of the :obj:`special` option
"""
pass
class Encoding:
"""
@ -891,6 +899,14 @@ class Tokenizer:
:class:`~tokenizers.Tokenizer`: The new tokenizer
"""
pass
def get_added_tokens_decoder(self):
"""
Get the added tokens decoder, mapping the added token ids to their AddedToken
Returns:
:obj:`Dict[int, AddedToken]`: The added tokens decoder
"""
pass
def get_vocab(self, with_added_tokens=True):
"""
Get the underlying vocabulary

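As a quick illustration of the stub above (a sketch against the post-PR bindings, with example values chosen for this note), the new `special` attribute is both readable and writable, and `content` now has a setter as well:

```python
from tokenizers import AddedToken

token = AddedToken("<mask>", lstrip=True)

# `special` defaults to False and marks tokens that should be skipped when decoding.
assert token.special is False

token.special = True       # new setter introduced by this PR
token.content = "<MASK>"   # content is writable as well
print(repr(token))         # the repr now includes `special=True`
```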
View File

@ -42,6 +42,14 @@ class BaseTokenizer:
"""
return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
def get_added_tokens_decoder(self) -> Dict[int, AddedToken]:
"""Returns the added reverse vocabulary
Returns:
The added vocabulary mapping ints to AddedTokens
"""
return self._tokenizer.get_added_tokens_decoder()
def get_vocab_size(self, with_added_tokens: bool = True) -> int:
"""Return the size of vocabulary, with or without added tokens.

View File

@ -25,6 +25,7 @@ use super::pre_tokenizers::PyPreTokenizer;
use super::trainers::PyTrainer;
use crate::processors::PyPostProcessor;
use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
use std::collections::BTreeMap;
/// Represents a token that can be added to a :class:`~tokenizers.Tokenizer`.
/// It can have special options that define the way it should behave.
@ -55,21 +56,23 @@ use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
/// text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
/// lowercasing the text, the token could be extracted from the input ``"I saw a lion
/// Yesterday"``.
/// special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
/// Defines whether this token should be skipped when decoding.
///
#[pyclass(dict, module = "tokenizers", name = "AddedToken")]
pub struct PyAddedToken {
pub content: String,
pub is_special_token: bool,
pub special: bool,
pub single_word: Option<bool>,
pub lstrip: Option<bool>,
pub rstrip: Option<bool>,
pub normalized: Option<bool>,
}
impl PyAddedToken {
pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
Self {
content: content.into(),
is_special_token: is_special_token.unwrap_or(false),
special: special.unwrap_or(false),
single_word: None,
lstrip: None,
rstrip: None,
@ -78,7 +81,7 @@ impl PyAddedToken {
}
pub fn get_token(&self) -> tk::tokenizer::AddedToken {
let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
let mut token = tk::AddedToken::from(&self.content, self.special);
if let Some(sw) = self.single_word {
token = token.single_word(sw);
@ -105,6 +108,7 @@ impl PyAddedToken {
dict.set_item("lstrip", token.lstrip)?;
dict.set_item("rstrip", token.rstrip)?;
dict.set_item("normalized", token.normalized)?;
dict.set_item("special", token.special)?;
Ok(dict)
}
@ -118,7 +122,7 @@ impl From<tk::AddedToken> for PyAddedToken {
lstrip: Some(token.lstrip),
rstrip: Some(token.rstrip),
normalized: Some(token.normalized),
is_special_token: !token.normalized,
special: token.special,
}
}
}
@ -126,7 +130,7 @@ impl From<tk::AddedToken> for PyAddedToken {
#[pymethods]
impl PyAddedToken {
#[new]
#[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True)")]
#[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False)")]
fn __new__(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = PyAddedToken::from(content.unwrap_or(""), None);
@ -138,6 +142,7 @@ impl PyAddedToken {
"lstrip" => token.lstrip = Some(value.extract()?),
"rstrip" => token.rstrip = Some(value.extract()?),
"normalized" => token.normalized = Some(value.extract()?),
"special" => token.special = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
@ -161,6 +166,7 @@ impl PyAddedToken {
"lstrip" => self.lstrip = Some(value.extract()?),
"rstrip" => self.rstrip = Some(value.extract()?),
"normalized" => self.normalized = Some(value.extract()?),
"special" => self.special = value.extract()?,
_ => {}
}
}
@ -176,6 +182,12 @@ impl PyAddedToken {
&self.content
}
/// Set the content of this :obj:`AddedToken`
#[setter]
fn set_content(&mut self, content: String) {
self.content = content;
}
/// Get the value of the :obj:`rstrip` option
#[getter]
fn get_rstrip(&self) -> bool {
@ -199,6 +211,17 @@ impl PyAddedToken {
fn get_normalized(&self) -> bool {
self.get_token().normalized
}
/// Get the value of the :obj:`special` option
#[getter]
fn get_special(&self) -> bool {
self.get_token().special
}
/// Set the value of the :obj:`special` option
#[setter]
fn set_special(&mut self, special: bool) {
self.special = special;
}
fn __str__(&self) -> PyResult<&str> {
Ok(&self.content)
@ -212,12 +235,13 @@ impl PyAddedToken {
let token = self.get_token();
Ok(format!(
"AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
"AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={}, special={})",
self.content,
bool_to_python(token.rstrip),
bool_to_python(token.lstrip),
bool_to_python(token.single_word),
bool_to_python(token.normalized)
bool_to_python(token.normalized),
bool_to_python(token.special)
))
}
@ -639,6 +663,22 @@ impl PyTokenizer {
self.tokenizer.get_vocab(with_added_tokens)
}
/// Get the added tokens decoder, mapping the added token ids to their AddedToken
///
/// Returns:
/// :obj:`Dict[int, AddedToken]`: The added tokens decoder
#[pyo3(signature = ())]
#[pyo3(text_signature = "(self)")]
fn get_added_tokens_decoder(&self) -> BTreeMap<u32, PyAddedToken> {
let mut sorted_map = BTreeMap::new();
for (key, value) in self.tokenizer.get_added_tokens_decoder() {
sorted_map.insert(key, value.into());
}
sorted_map
}
/// Get the size of the underlying vocabulary
///
/// Args:
@ -1090,7 +1130,7 @@ impl PyTokenizer {
if let Ok(content) = token.extract::<String>() {
Ok(PyAddedToken::from(content, Some(false)).get_token())
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = false;
token.special = false;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -1127,7 +1167,7 @@ impl PyTokenizer {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(

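Since `add_tokens` forces `special = False` and `add_special_tokens` forces `special = True`, and `get_added_tokens_decoder` collects the map into a `BTreeMap` (so the returned dict is ordered by id), the observable Python behaviour is roughly the following sketch (exact ids assume an empty `BPE` model):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name"])     # special is forced to False here
tokenizer.add_special_tokens(["[CLS]"])  # special is forced to True here

decoder = tokenizer.get_added_tokens_decoder()  # dict ordered by id
assert not decoder[0].special and not decoder[1].special
assert decoder[2].special
```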
View File

@ -226,7 +226,7 @@ impl PyBpeTrainer {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -319,7 +319,7 @@ impl PyBpeTrainer {
} else if let Ok(mut token) =
token.extract::<PyRefMut<PyAddedToken>>()
{
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -440,7 +440,7 @@ impl PyWordPieceTrainer {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -526,7 +526,7 @@ impl PyWordPieceTrainer {
} else if let Ok(mut token) =
token.extract::<PyRefMut<PyAddedToken>>()
{
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -632,7 +632,7 @@ impl PyWordLevelTrainer {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -673,7 +673,7 @@ impl PyWordLevelTrainer {
} else if let Ok(mut token) =
token.extract::<PyRefMut<PyAddedToken>>()
{
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -778,7 +778,7 @@ impl PyUnigramTrainer {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(
@ -846,7 +846,7 @@ impl PyUnigramTrainer {
} else if let Ok(mut token) =
token.extract::<PyRefMut<PyAddedToken>>()
{
token.is_special_token = true;
token.special = true;
Ok(token.get_token())
} else {
Err(exceptions::PyTypeError::new_err(

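All trainers now wrap their `special_tokens` with `special = True`, which is what the updated trainer tests later in this diff check. A minimal sketch (hypothetical token names):

```python
from tokenizers import AddedToken
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(special_tokens=["[UNK]", "[PAD]"])

# Plain strings are promoted to AddedTokens flagged as special.
assert all(token.special for token in trainer.special_tokens)
assert trainer.special_tokens == [
    AddedToken("[UNK]", special=True),
    AddedToken("[PAD]", special=True),
]
```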
View File

@ -16,10 +16,19 @@ from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, robe
class TestAddedToken:
def test_instantiate_with_content_only(self):
added_token = AddedToken("<mask>")
added_token.content = "<MASK>"
assert added_token.content == "<MASK>"
assert type(added_token) == AddedToken
added_token.content = added_token.content.lower()
assert added_token.special == False
added_token.special = True
assert added_token.special == True
added_token.special = False
assert str(added_token) == "<mask>"
assert (
repr(added_token) == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
repr(added_token)
== 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)'
)
assert added_token.rstrip == False
assert added_token.lstrip == False
@ -365,6 +374,16 @@ class TestTokenizer:
vocab = tokenizer.get_vocab(with_added_tokens=False)
assert vocab == {}
# Can retrieve the added tokens decoder
vocab = tokenizer.get_added_tokens_decoder()
assert vocab == {
0: AddedToken("my", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
1: AddedToken("name", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
2: AddedToken("is", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
3: AddedToken("john", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
4: AddedToken("pair", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
}
def test_get_vocab_size(self):
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

View File

@ -34,8 +34,8 @@ class TestBpeTrainer:
assert trainer.min_frequency == 12
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1"),
AddedToken("2"),
AddedToken("1", special=True),
AddedToken("2", special=True),
]
assert trainer.limit_alphabet == 13
assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@ -91,8 +91,8 @@ class TestWordPieceTrainer:
assert trainer.min_frequency == 12
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1"),
AddedToken("2"),
AddedToken("1", special=True),
AddedToken("2", special=True),
]
assert trainer.limit_alphabet == 13
assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@ -131,8 +131,8 @@ class TestWordLevelTrainer:
assert trainer.min_frequency == 12
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1"),
AddedToken("2"),
AddedToken("1", special=True),
AddedToken("2", special=True),
]
# Modify these
@ -272,8 +272,8 @@ class TestUnigram:
assert trainer.vocab_size == 12345
assert trainer.show_progress == False
assert trainer.special_tokens == [
AddedToken("1", normalized=False),
AddedToken("2", lstrip=True, normalized=False),
AddedToken("1", normalized=False, special=True),
AddedToken("2", lstrip=True, normalized=False, special=True),
]
assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]

View File

@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet};
/// like:
/// - Whether they should only match single words
/// - Whether to include any whitespace on its left or right
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AddedToken {
/// The content of the added token
pub content: String,
@ -66,6 +66,12 @@ impl AddedToken {
self.normalized = normalized;
self
}
/// Specify whether this token is special, i.e. whether it should be skipped when decoding
#[must_use]
pub fn special(mut self, special: bool) -> Self {
self.special = special;
self
}
}
impl Default for AddedToken {
fn default() -> Self {
@ -79,19 +85,12 @@ impl Default for AddedToken {
}
}
}
// We only want to hash on the content. AddedToken cannot be added multiple times with different
// options
// AddedTokens can be updated if their value changes
impl std::hash::Hash for AddedToken {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.content.hash(state);
}
}
impl std::cmp::PartialEq for AddedToken {
fn eq(&self, other: &Self) -> bool {
self.content == other.content
}
}
impl std::cmp::Eq for AddedToken {}
type MatchingSet = (AhoCorasick, Vec<u32>);
@ -181,8 +180,8 @@ impl AddedVocabulary {
split_normalized_trie: (normalized_trie, vec![]),
}
}
/// Size of the additional vocabulary
#[allow(dead_code)] // Suppress the "method is never used" warning
pub fn len(&self) -> usize {
self.added_tokens_map.len()
}
@ -192,6 +191,11 @@ impl AddedVocabulary {
&self.added_tokens_map
}
/// Get the additional vocabulary, mapping ids to their AddedToken
pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
&self.added_tokens_map_r
}
/// Get the id matching one of our tokens if it exists
pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
self.added_tokens_map
@ -244,30 +248,42 @@ impl AddedVocabulary {
// Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
let mut ignored = 0;
for token in tokens {
if token.content.is_empty() {
if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
{
ignored += 1;
continue;
}
let id = if let Some(id) = self.token_to_id(&token.content, model) {
ignored += 1;
id
// If a token is already part of the vocabulary, we mark it as added
let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
new_id
} else {
let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
self.added_tokens_map.insert(token.content.clone(), new_id);
self.added_tokens_map.values().cloned().max().map_or(
model.get_vocab_size() as u32,
|max| {
if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
max + 1
} else {
model.get_vocab_size() as u32
}
},
)
};
// Make sure we modify the previous entry
self.added_tokens_map
.entry(token.content.clone())
.and_modify(|old_id| *old_id = new_id)
.or_insert_with(|| new_id);
// Update the current revert operation
self.added_tokens_map_r
.entry(new_id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());
// Make sure to remove previous entry (if the token gets a new id)
// Finally, add the token to the classic set of added tokens if it is not special
if !self.special_tokens_set.contains(&token.content) {
self.added_tokens.push(token.clone());
}
new_id
};
// Update the current revert operation
self.added_tokens_map_r
.entry(id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());
}
self.refresh_added_tokens(model, normalizer);
@ -569,7 +585,9 @@ mod tests {
),
1
);
assert_eq!(vocab.len(), 1);
let vocab_len: usize = vocab.len();
assert_eq!(vocab_len, 1);
// Does not add the same token multiple times
assert_eq!(
@ -585,12 +603,15 @@ mod tests {
);
assert_eq!(vocab.len(), 2);
// Does not add tokens already covered by the model
// Also adds tokens already covered by the model
let added_token = AddedToken::from("test", false);
assert_eq!(
vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
0
vocab.add_tokens(&[added_token.clone()], &model, normalizer),
1
);
assert_eq!(vocab.len(), 2);
assert_eq!(vocab.len(), 3);
assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
}
#[test]
@ -626,11 +647,47 @@ mod tests {
// Can add tokens already covered by the model
assert_eq!(
vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
0
1
);
assert_eq!(vocab.len(), 2); // Did not add a new token, since it exists in the original model
assert_eq!(vocab.len(), 3); // New token was added
assert!(vocab.is_special_token("test"));
assert!(!vocab.added_tokens_map.contains_key("test"));
assert_eq!(
*vocab.get_added_tokens_decoder(),
HashMap::from([
(0, AddedToken::from("test", true)),
(2, AddedToken::from("added_token_1", true)),
(3, AddedToken::from("added_token_2", true)),
])
);
assert!(vocab.added_tokens_map.contains_key("test"));
assert!(vocab.added_tokens_map_r.contains_key(&0));
vocab.add_tokens(
&[
AddedToken::from("tost", true),
AddedToken::from("another_two", false),
],
&model,
normalizer,
);
assert_eq!(vocab.len(), 5); // New tokens were added
assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but its id is not the length of the vocab
// Let's add an already added token again
assert_eq!(
vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer),
1
);
assert_eq!(vocab.len(), 5); // Token was already there
assert_eq!(vocab.get_vocab()["another_two"], 4); // Token idx not changed
// Just checking that we can mutate the token's fields from Rust
let mut token: AddedToken = AddedToken::from("Hey", false);
token.content = "hey".to_string();
assert_eq!(token.content, "hey"); // Content was updated
token.special = true;
assert!(token.special); // The special flag was set
}
#[test]
@ -766,6 +823,8 @@ mod tests {
let mut vocab = AddedVocabulary::new();
let normalizer = Lowercase;
assert_eq!(vocab.len(), 0);
vocab.add_tokens(
&[AddedToken::from("<mask>", false).single_word(true)],
&model,

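The reworked `add_tokens` above no longer skips tokens that the model already knows: such a token is recorded in the added-tokens maps under its existing id, while a genuinely new token gets `max(existing added ids) + 1` when that maximum reaches the model's vocab size (or the model is empty), and the model's vocab size otherwise. A Python-side sketch of that behaviour, assuming a toy `WordLevel` vocabulary:

```python
from tokenizers import AddedToken, Tokenizer
from tokenizers.models import WordLevel

# "test" is already part of the model vocabulary with id 0.
tokenizer = Tokenizer(WordLevel({"test": 0, "[UNK]": 1}, unk_token="[UNK]"))

# Previously a no-op; now the token is tracked as an added token under its
# existing id, so its options (lstrip, special, ...) can be updated.
tokenizer.add_tokens([AddedToken("test", lstrip=True)])
decoder = tokenizer.get_added_tokens_decoder()
assert decoder[0].content == "test"

# A genuinely new token continues after the model vocabulary.
tokenizer.add_tokens(["brand_new"])
assert tokenizer.token_to_id("brand_new") == 2
```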
View File

@ -659,13 +659,19 @@ where
final_vocab
}
/// Get the added tokens decoder
pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
self.added_vocabulary.get_added_tokens_decoder().clone()
}
/// Get the size of the vocabulary
pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
self.model.get_vocab_size()
+ if with_added_tokens {
self.added_vocabulary.len()
// TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
// now some tokens can be both in the added_tokens_encoder and in the vocab
if with_added_tokens {
self.get_vocab(true).len()
} else {
0
self.model.get_vocab_size()
}
}
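With `get_vocab_size` now deferring to the union returned by `get_vocab(true)`, a token that lives both in the model vocabulary and in the added vocabulary is counted only once. Continuing the toy `WordLevel` example from above:

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel

tokenizer = Tokenizer(WordLevel({"test": 0, "[UNK]": 1}, unk_token="[UNK]"))
tokenizer.add_tokens(["test", "extra"])  # "test" overlaps with the model vocab

assert tokenizer.get_vocab_size(with_added_tokens=False) == 2
# The overlapping token is not double-counted: 3, not 4.
assert tokenizer.get_vocab_size(with_added_tokens=True) == 3
assert tokenizer.get_vocab_size(with_added_tokens=True) == len(
    tokenizer.get_vocab(with_added_tokens=True)
)
```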