Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Merge pull request #1335 from ArthurZucker/update-added-tokens
Update added tokens
@@ -30,10 +30,12 @@ class AddedToken:
             text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
             lowercasing the text, the token could be extracted from the input ``"I saw a lion
             Yesterday"``.
+        special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
+            Defines whether this token should be skipped when decoding.

    """

-    def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True):
+    def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False):
         pass

     @property
     def content(self):
@@ -65,6 +67,12 @@ class AddedToken:
         Get the value of the :obj:`single_word` option
         """
         pass
+    @property
+    def special(self):
+        """
+        Get the value of the :obj:`special` option
+        """
+        pass

 class Encoding:
     """
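The new `special` flag is what `skip_special_tokens` consults at decode time. A minimal sketch of the intended Python usage, assuming a toy WordLevel vocab and an illustrative `<eot>` token that are not part of this diff:

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import WordLevel
    from tokenizers.pre_tokenizers import Whitespace

    # Toy model: two entries are enough to show the behaviour.
    tokenizer = Tokenizer(WordLevel(vocab={"hello": 0, "[UNK]": 1}, unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()

    # add_special_tokens marks the token with special=True (see the Rust changes below).
    tokenizer.add_special_tokens([AddedToken("<eot>", special=True)])

    ids = tokenizer.encode("hello <eot>").ids
    print(tokenizer.decode(ids))                             # "hello"        <- special token skipped
    print(tokenizer.decode(ids, skip_special_tokens=False))  # "hello <eot>"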
@@ -891,6 +899,14 @@ class Tokenizer:
             :class:`~tokenizers.Tokenizer`: The new tokenizer
         """
         pass
+    def get_added_tokens_decoder(self):
+        """
+        Get the added tokens decoder, mapping ids to :class:`~tokenizers.AddedToken`
+
+        Returns:
+            :obj:`Dict[int, AddedToken]`: The added vocabulary
+        """
+        pass
     def get_vocab(self, with_added_tokens=True):
         """
         Get the underlying vocabulary
@@ -42,6 +42,14 @@ class BaseTokenizer:
         """
         return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)

+    def get_added_tokens_decoder(self) -> Dict[int, AddedToken]:
+        """Returns the added reverse vocabulary
+
+        Returns:
+            The added vocabulary mapping ints to AddedTokens
+        """
+        return self._tokenizer.get_added_tokens_decoder()
+
     def get_vocab_size(self, with_added_tokens: bool = True) -> int:
         """Return the size of vocabulary, with or without added tokens.

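A rough usage sketch of the new accessor, using the same empty `BPE()` setup as the tests further down (the ids and repr shown in comments are indicative only). `get_vocab()` exposes just a content-to-id mapping, while the decoder keeps every option of the stored `AddedToken`:

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_tokens(["my", AddedToken("name", lstrip=True)])
    tokenizer.add_special_tokens([AddedToken("<pad>", special=True)])

    print(tokenizer.get_vocab())  # {'my': 0, 'name': 1, '<pad>': 2} (dict order may vary)
    for token_id, token in tokenizer.get_added_tokens_decoder().items():
        # e.g. 2 AddedToken("<pad>", ..., special=True)
        print(token_id, repr(token))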
@@ -25,6 +25,7 @@ use super::pre_tokenizers::PyPreTokenizer;
 use super::trainers::PyTrainer;
 use crate::processors::PyPostProcessor;
 use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
+use std::collections::BTreeMap;

 /// Represents a token that can be added to a :class:`~tokenizers.Tokenizer`.
 /// It can have special options that define the way it should behave.
@@ -55,21 +56,23 @@ use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
 ///         text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
 ///         lowercasing the text, the token could be extracted from the input ``"I saw a lion
 ///         Yesterday"``.
+///     special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
+///         Defines whether this token should be skipped when decoding.
 ///
 #[pyclass(dict, module = "tokenizers", name = "AddedToken")]
 pub struct PyAddedToken {
     pub content: String,
-    pub is_special_token: bool,
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
     pub normalized: Option<bool>,
 }
 impl PyAddedToken {
-    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-            is_special_token: is_special_token.unwrap_or(false),
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -78,7 +81,7 @@ impl PyAddedToken {
     }

     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        let mut token = tk::AddedToken::from(&self.content, self.special);

         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -105,6 +108,7 @@ impl PyAddedToken {
         dict.set_item("lstrip", token.lstrip)?;
         dict.set_item("rstrip", token.rstrip)?;
         dict.set_item("normalized", token.normalized)?;
+        dict.set_item("special", token.special)?;

         Ok(dict)
     }
@@ -118,7 +122,7 @@ impl From<tk::AddedToken> for PyAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            is_special_token: !token.normalized,
+            special: token.special,
         }
     }
 }
@@ -126,7 +130,7 @@ impl From<tk::AddedToken> for PyAddedToken {
 #[pymethods]
 impl PyAddedToken {
     #[new]
-    #[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True)")]
+    #[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False)")]
     fn __new__(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
         let mut token = PyAddedToken::from(content.unwrap_or(""), None);

@@ -138,6 +142,7 @@ impl PyAddedToken {
                     "lstrip" => token.lstrip = Some(value.extract()?),
                     "rstrip" => token.rstrip = Some(value.extract()?),
                     "normalized" => token.normalized = Some(value.extract()?),
+                    "special" => token.special = value.extract()?,
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
@@ -161,6 +166,7 @@ impl PyAddedToken {
                     "lstrip" => self.lstrip = Some(value.extract()?),
                     "rstrip" => self.rstrip = Some(value.extract()?),
                     "normalized" => self.normalized = Some(value.extract()?),
+                    "special" => self.special = value.extract()?,
                     _ => {}
                 }
             }
@@ -176,6 +182,12 @@ impl PyAddedToken {
         &self.content
     }

+    /// Set the content of this :obj:`AddedToken`
+    #[setter]
+    fn set_content(&mut self, content: String) {
+        self.content = content;
+    }
+
     /// Get the value of the :obj:`rstrip` option
     #[getter]
     fn get_rstrip(&self) -> bool {
@@ -199,6 +211,17 @@ impl PyAddedToken {
     fn get_normalized(&self) -> bool {
         self.get_token().normalized
     }
+    /// Get the value of the :obj:`special` option
+    #[getter]
+    fn get_special(&self) -> bool {
+        self.get_token().special
+    }
+
+    /// Set the value of the :obj:`special` option
+    #[setter]
+    fn set_special(&mut self, special: bool) {
+        self.special = special;
+    }

     fn __str__(&self) -> PyResult<&str> {
         Ok(&self.content)
@@ -212,12 +235,13 @@ impl PyAddedToken {

         let token = self.get_token();
         Ok(format!(
-            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
+            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={}, special={})",
             self.content,
             bool_to_python(token.rstrip),
             bool_to_python(token.lstrip),
             bool_to_python(token.single_word),
-            bool_to_python(token.normalized)
+            bool_to_python(token.normalized),
+            bool_to_python(token.special)
         ))
     }

@@ -639,6 +663,22 @@ impl PyTokenizer {
         self.tokenizer.get_vocab(with_added_tokens)
     }

+    /// Get the added tokens decoder, mapping ids to :class:`~tokenizers.AddedToken`
+    ///
+    /// Returns:
+    ///     :obj:`Dict[int, AddedToken]`: The added vocabulary
+    #[pyo3(signature = ())]
+    #[pyo3(text_signature = "(self)")]
+    fn get_added_tokens_decoder(&self) -> BTreeMap<u32, PyAddedToken> {
+        let mut sorted_map = BTreeMap::new();
+
+        for (key, value) in self.tokenizer.get_added_tokens_decoder() {
+            sorted_map.insert(key, value.into());
+        }
+
+        sorted_map
+    }
+
     /// Get the size of the underlying vocabulary
     ///
     /// Args:
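Because the entries are collected into a BTreeMap before crossing into Python, the returned dict is ordered by id. A quick standalone check; the token contents below are arbitrary:

    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_tokens(["zebra", "apple", "mango"])

    # Keys arrive sorted by id regardless of insertion or content order.
    print(list(tokenizer.get_added_tokens_decoder()))  # [0, 1, 2]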
@@ -1090,7 +1130,7 @@ impl PyTokenizer {
         if let Ok(content) = token.extract::<String>() {
             Ok(PyAddedToken::from(content, Some(false)).get_token())
         } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-            token.is_special_token = false;
+            token.special = false;
             Ok(token.get_token())
         } else {
             Err(exceptions::PyTypeError::new_err(
@@ -1127,7 +1167,7 @@ impl PyTokenizer {
         if let Ok(content) = token.extract::<String>() {
             Ok(tk::tokenizer::AddedToken::from(content, true))
         } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-            token.is_special_token = true;
+            token.special = true;
             Ok(token.get_token())
         } else {
             Err(exceptions::PyTypeError::new_err(
@@ -226,7 +226,7 @@ impl PyBpeTrainer {
         if let Ok(content) = token.extract::<String>() {
             Ok(tk::tokenizer::AddedToken::from(content, true))
         } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-            token.is_special_token = true;
+            token.special = true;
             Ok(token.get_token())
         } else {
             Err(exceptions::PyTypeError::new_err(
@@ -319,7 +319,7 @@ impl PyBpeTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -440,7 +440,7 @@ impl PyWordPieceTrainer {
         if let Ok(content) = token.extract::<String>() {
             Ok(tk::tokenizer::AddedToken::from(content, true))
         } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-            token.is_special_token = true;
+            token.special = true;
             Ok(token.get_token())
         } else {
             Err(exceptions::PyTypeError::new_err(
@@ -526,7 +526,7 @@ impl PyWordPieceTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -632,7 +632,7 @@ impl PyWordLevelTrainer {
         if let Ok(content) = token.extract::<String>() {
             Ok(tk::tokenizer::AddedToken::from(content, true))
         } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-            token.is_special_token = true;
+            token.special = true;
             Ok(token.get_token())
         } else {
             Err(exceptions::PyTypeError::new_err(
@@ -673,7 +673,7 @@ impl PyWordLevelTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -778,7 +778,7 @@ impl PyUnigramTrainer {
         if let Ok(content) = token.extract::<String>() {
             Ok(tk::tokenizer::AddedToken::from(content, true))
         } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-            token.is_special_token = true;
+            token.special = true;
             Ok(token.get_token())
         } else {
             Err(exceptions::PyTypeError::new_err(
@@ -846,7 +846,7 @@ impl PyUnigramTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -16,10 +16,19 @@ from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, robe
 class TestAddedToken:
     def test_instantiate_with_content_only(self):
         added_token = AddedToken("<mask>")
+        added_token.content = "<MASK>"
+        assert added_token.content == "<MASK>"
         assert type(added_token) == AddedToken
+        added_token.content = added_token.content.lower()
+
+        assert added_token.special == False
+        added_token.special = True
+        assert added_token.special == True
+        added_token.special = False
         assert str(added_token) == "<mask>"
         assert (
-            repr(added_token) == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
+            repr(added_token)
+            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)'
         )
         assert added_token.rstrip == False
         assert added_token.lstrip == False
@@ -365,6 +374,16 @@ class TestTokenizer:
         vocab = tokenizer.get_vocab(with_added_tokens=False)
         assert vocab == {}

+        # Can retrieve added token decoder
+        vocab = tokenizer.get_added_tokens_decoder()
+        assert vocab == {
+            0: AddedToken("my", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            1: AddedToken("name", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            2: AddedToken("is", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            3: AddedToken("john", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            4: AddedToken("pair", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+        }
+
     def test_get_vocab_size(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
@@ -34,8 +34,8 @@ class TestBpeTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -91,8 +91,8 @@ class TestWordPieceTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -131,8 +131,8 @@ class TestWordLevelTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]

         # Modify these
@@ -272,8 +272,8 @@ class TestUnigram:
         assert trainer.vocab_size == 12345
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1", normalized=False),
-            AddedToken("2", lstrip=True, normalized=False),
+            AddedToken("1", normalized=False, special=True),
+            AddedToken("2", lstrip=True, normalized=False, special=True),
         ]
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]

@@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet};
 /// like:
 /// - Whether they should only match single words
 /// - Whether to include any whitespace on its left or right
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub struct AddedToken {
     /// The content of the added token
     pub content: String,
@@ -66,6 +66,12 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
+    /// Specify whether this token is special, i.e. whether it should be skipped when decoding
+    #[must_use]
+    pub fn special(mut self, special: bool) -> Self {
+        self.special = special;
+        self
+    }
 }
 impl Default for AddedToken {
     fn default() -> Self {
@@ -79,19 +85,12 @@ impl Default for AddedToken {
         }
     }
 }
-// We only want to hash on the content. AddedToken cannot be added multiple times with different
-// options
+// AddedTokens can be updated if value changed
 impl std::hash::Hash for AddedToken {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
         self.content.hash(state);
     }
 }
-impl std::cmp::PartialEq for AddedToken {
-    fn eq(&self, other: &Self) -> bool {
-        self.content == other.content
-    }
-}
-impl std::cmp::Eq for AddedToken {}

 type MatchingSet = (AhoCorasick, Vec<u32>);

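Net effect of the two hunks above: hashing still looks only at the content, while equality is now the derived field-wise comparison, which is why the trainer tests earlier spell out special=True. A small Python-level illustration, assuming the binding exposes the same comparison (the updated tests suggest it does):

    from tokenizers import AddedToken

    # Same content, different options: no longer considered equal.
    assert AddedToken("1", special=True) != AddedToken("1")
    assert AddedToken("1", special=True) == AddedToken("1", special=True)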
@@ -181,8 +180,8 @@ impl AddedVocabulary {
             split_normalized_trie: (normalized_trie, vec![]),
         }
     }

     /// Size of the additional vocabulary
+    #[allow(dead_code)] // Suppress the "method is never used" warning
     pub fn len(&self) -> usize {
         self.added_tokens_map.len()
     }
@@ -192,6 +191,11 @@ impl AddedVocabulary {
         &self.added_tokens_map
     }

+    /// Get the additional vocabulary with the AddedTokens
+    pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
+        &self.added_tokens_map_r
+    }
+
     /// Get the id matching one of our tokens, if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
         self.added_tokens_map
@@ -244,30 +248,42 @@ impl AddedVocabulary {
         // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
         for token in tokens {
-            if token.content.is_empty() {
+            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
+            {
                 ignored += 1;
                 continue;
             }
-
-            let id = if let Some(id) = self.token_to_id(&token.content, model) {
-                ignored += 1;
-                id
-            } else {
-                let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
-
-                if !self.special_tokens_set.contains(&token.content) {
-                    self.added_tokens.push(token.clone());
-                }
-
+            // If a token is already part of the vocabulary, we mark it as added
+            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
+                new_id
+            } else {
+                self.added_tokens_map.values().cloned().max().map_or(
+                    model.get_vocab_size() as u32,
+                    |max| {
+                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
+                            max + 1
+                        } else {
+                            model.get_vocab_size() as u32
+                        }
+                    },
+                )
+            };
+
+            // Make sure we modify the previous entry
+            self.added_tokens_map
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = new_id)
+                .or_insert_with(|| new_id);
             // Update the current revert operation
             self.added_tokens_map_r
-                .entry(id)
+                .entry(new_id)
                 .and_modify(|t| *t = token.clone())
                 .or_insert_with(|| token.clone());
+            // Make sure to remove previous entry (if the token gets a new id)
+
+            // Finally add the token to the classic set if special
+            if !self.special_tokens_set.contains(&token.content) {
+                self.added_tokens.push(token.clone());
+            }
         }

         self.refresh_added_tokens(model, normalizer);
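In short, the reworked loop reuses the model's id when the content is already in the model vocabulary, otherwise hands out max(existing added ids) + 1 (falling back to the model's vocab size when nothing has been added yet), and re-adding an existing token updates its entry instead of being ignored. A Python-level sketch of how that plays out, with a toy WordLevel vocab chosen purely for illustration:

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel

    # "hi" is already in the model vocab, "new" is not.
    tokenizer = Tokenizer(WordLevel(vocab={"hi": 0, "[UNK]": 1}, unk_token="[UNK]"))
    tokenizer.add_tokens(["hi", "new"])

    # "hi" is registered as an added token but keeps its model id 0,
    # while "new" gets the next free id, 2 (the model vocab size here).
    print(tokenizer.get_added_tokens_decoder())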
@@ -569,7 +585,9 @@ mod tests {
             ),
             1
         );
-        assert_eq!(vocab.len(), 1);
+
+        let vocab_len: usize = vocab.len();
+        assert_eq!(vocab_len, 1);

         // Does not add multiple time the same token
         assert_eq!(
@@ -585,12 +603,15 @@ mod tests {
         );
         assert_eq!(vocab.len(), 2);

-        // Does not add tokens already covered by the model
+        // Also adds tokens already covered by the model
+        let added_token = AddedToken::from("test", false);
         assert_eq!(
-            vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
-            0
+            vocab.add_tokens(&[added_token.clone()], &model, normalizer),
+            1
         );
-        assert_eq!(vocab.len(), 2);
+        assert_eq!(vocab.len(), 3);
+
+        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
     }

     #[test]
@@ -626,11 +647,47 @@ mod tests {
         // Can add tokens already covered by the model
         assert_eq!(
             vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
-            0
+            1
         );
-        assert_eq!(vocab.len(), 2); // Did not add a new token, since it exist in the original model
+        assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
-        assert!(!vocab.added_tokens_map.contains_key("test"));
+        assert_eq!(
+            *vocab.get_added_tokens_decoder(),
+            HashMap::from([
+                (0, AddedToken::from("test", true)),
+                (2, AddedToken::from("added_token_1", true)),
+                (3, AddedToken::from("added_token_2", true)),
+            ])
+        );
+        assert!(vocab.added_tokens_map.contains_key("test"));
+        assert!(vocab.added_tokens_map_r.contains_key(&0));
+
+        vocab.add_tokens(
+            &[
+                AddedToken::from("tost", true),
+                AddedToken::from("another_two", false),
+            ],
+            &model,
+            normalizer,
+        );
+        assert_eq!(vocab.len(), 5); // New token was added
+        assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab
+
+        // Let's add an already added token again
+        assert_eq!(
+            vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer),
+            1
+        );
+        assert_eq!(vocab.len(), 5); // Token was already there
+        assert_eq!(vocab.get_vocab()["another_two"], 4); // Token idx not changed
+
+        // Just checking that we can set the content of the string in rust
+        let mut token: AddedToken = AddedToken::from("Hey", false);
+        token.content = "hey".to_string();
+        assert_eq!(token.content, "hey"); // Token was already there
+
+        token.special = true;
+        assert!(token.special); // Token was already there
     }

     #[test]
@@ -766,6 +823,8 @@ mod tests {
         let mut vocab = AddedVocabulary::new();
         let normalizer = Lowercase;

+        assert_eq!(vocab.len(), 0);
+
         vocab.add_tokens(
             &[AddedToken::from("<mask>", false).single_word(true)],
             &model,
@@ -659,14 +659,20 @@ where
         final_vocab
     }

+    /// Get the added tokens decoder
+    pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
+        self.added_vocabulary.get_added_tokens_decoder().clone()
+    }
+
     /// Get the size of the vocabulary
     pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
-        self.model.get_vocab_size()
-            + if with_added_tokens {
-                self.added_vocabulary.len()
-            } else {
-                0
-            }
+        // TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
+        // now some tokens can be both in the added_tokens_encoder and in the vocab
+        if with_added_tokens {
+            self.get_vocab(true).len()
+        } else {
+            self.model.get_vocab_size()
+        }
     }

     /// Converts a token into the corresponding id.
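Why the old formula had to go: an added token can now share its id with a model entry, so model size + number of added tokens would double-count it. A sketch with the same toy WordLevel vocab as above (illustrative values only):

    from tokenizers import Tokenizer
    from tokenizers.models import WordLevel

    tokenizer = Tokenizer(WordLevel(vocab={"hi": 0, "[UNK]": 1}, unk_token="[UNK]"))
    tokenizer.add_tokens(["hi", "new"])  # "hi" overlaps with the model vocab

    # Both counts agree after this change: the union has 3 entries, not 2 + 2.
    print(len(tokenizer.get_vocab(with_added_tokens=True)))   # 3
    print(tokenizer.get_vocab_size(with_added_tokens=True))   # 3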