mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-30 20:19:21 +00:00)
Merge pull request #1335 from ArthurZucker/update-added-tokens
Update added tokens
@@ -30,10 +30,12 @@ class AddedToken:
         text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
         lowercasing the text, the token could be extracted from the input ``"I saw a lion
         Yesterday"``.
+        special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
+            Defines whether this token should be skipped when decoding.
 
     """
 
-    def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True):
+    def __init__(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False):
         pass
     @property
     def content(self):
@@ -65,6 +67,12 @@ class AddedToken:
         Get the value of the :obj:`single_word` option
         """
         pass
+    @property
+    def special(self):
+        """
+        Get the value of the :obj:`special` option
+        """
+        pass
 
 class Encoding:
     """
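For illustration, a minimal sketch of the new flag from the Python side (a hypothetical session, assuming a build of this revision; the behaviour mirrors the test updated further down):

    from tokenizers import AddedToken

    token = AddedToken("<mask>", special=True)   # the new keyword argument
    assert token.special                         # read through the new getter
    token.special = False                        # mutate through the new setter
    assert not token.special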
@@ -891,6 +899,14 @@ class Tokenizer:
             :class:`~tokenizers.Tokenizer`: The new tokenizer
         """
         pass
+    def get_added_tokens_decoder(self):
+        """
+        Get the underlying vocabulary
+
+        Returns:
+            :obj:`Dict[int, AddedToken]`: The vocabulary
+        """
+        pass
     def get_vocab(self, with_added_tokens=True):
         """
         Get the underlying vocabulary
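A short usage sketch of the new accessor (mirroring the Python test added later in this diff; the empty BPE model is an assumption):

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_tokens(["my", "name"])
    decoder = tokenizer.get_added_tokens_decoder()
    # {0: AddedToken("my", ..., special=False), 1: AddedToken("name", ..., special=False)}
    assert decoder[0].content == "my"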
@@ -42,6 +42,14 @@ class BaseTokenizer:
         """
         return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
 
+    def get_added_tokens_decoder(self) -> Dict[int, AddedToken]:
+        """Returns the added reverse vocabulary
+
+        Returns:
+            The added vocabulary mapping ints to AddedTokens
+        """
+        return self._tokenizer.get_added_tokens_decoder()
+
     def get_vocab_size(self, with_added_tokens: bool = True) -> int:
         """Return the size of vocabulary, with or without added tokens.
 
@@ -25,6 +25,7 @@ use super::pre_tokenizers::PyPreTokenizer;
 use super::trainers::PyTrainer;
 use crate::processors::PyPostProcessor;
 use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
+use std::collections::BTreeMap;
 
 /// Represents a token that can be added to a :class:`~tokenizers.Tokenizer`.
 /// It can have special options that define the way it should behave.
@@ -55,21 +56,23 @@ use crate::utils::{MaybeSizedIterator, PyBufferedIterator};
 ///     text. For example, with the added token ``"yesterday"``, and a normalizer in charge of
 ///     lowercasing the text, the token could be extracted from the input ``"I saw a lion
 ///     Yesterday"``.
+///     special (:obj:`bool`, defaults to :obj:`False` with :meth:`~tokenizers.Tokenizer.add_tokens` and :obj:`True` with :meth:`~tokenizers.Tokenizer.add_special_tokens`):
+///         Defines whether this token should be skipped when decoding.
 ///
 #[pyclass(dict, module = "tokenizers", name = "AddedToken")]
 pub struct PyAddedToken {
     pub content: String,
-    pub is_special_token: bool,
+    pub special: bool,
     pub single_word: Option<bool>,
     pub lstrip: Option<bool>,
     pub rstrip: Option<bool>,
     pub normalized: Option<bool>,
 }
 impl PyAddedToken {
-    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+    pub fn from<S: Into<String>>(content: S, special: Option<bool>) -> Self {
         Self {
             content: content.into(),
-            is_special_token: is_special_token.unwrap_or(false),
+            special: special.unwrap_or(false),
             single_word: None,
             lstrip: None,
             rstrip: None,
@@ -78,7 +81,7 @@ impl PyAddedToken {
     }
 
     pub fn get_token(&self) -> tk::tokenizer::AddedToken {
-        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        let mut token = tk::AddedToken::from(&self.content, self.special);
 
         if let Some(sw) = self.single_word {
             token = token.single_word(sw);
@@ -105,6 +108,7 @@ impl PyAddedToken {
         dict.set_item("lstrip", token.lstrip)?;
         dict.set_item("rstrip", token.rstrip)?;
         dict.set_item("normalized", token.normalized)?;
+        dict.set_item("special", token.special)?;
 
         Ok(dict)
     }
@@ -118,7 +122,7 @@ impl From<tk::AddedToken> for PyAddedToken {
             lstrip: Some(token.lstrip),
             rstrip: Some(token.rstrip),
             normalized: Some(token.normalized),
-            is_special_token: !token.normalized,
+            special: token.special,
         }
     }
 }
@@ -126,7 +130,7 @@ impl From<tk::AddedToken> for PyAddedToken {
 #[pymethods]
 impl PyAddedToken {
     #[new]
-    #[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True)")]
+    #[pyo3(signature = (content=None, **kwargs), text_signature = "(self, content, single_word=False, lstrip=False, rstrip=False, normalized=True, special=False)")]
     fn __new__(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
         let mut token = PyAddedToken::from(content.unwrap_or(""), None);
 
@@ -138,6 +142,7 @@ impl PyAddedToken {
                     "lstrip" => token.lstrip = Some(value.extract()?),
                     "rstrip" => token.rstrip = Some(value.extract()?),
                     "normalized" => token.normalized = Some(value.extract()?),
+                    "special" => token.special = value.extract()?,
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
@@ -161,6 +166,7 @@ impl PyAddedToken {
                     "lstrip" => self.lstrip = Some(value.extract()?),
                     "rstrip" => self.rstrip = Some(value.extract()?),
                     "normalized" => self.normalized = Some(value.extract()?),
+                    "special" => self.special = value.extract()?,
                     _ => {}
                 }
             }
@@ -176,6 +182,12 @@ impl PyAddedToken {
         &self.content
     }
 
+    /// Set the content of this :obj:`AddedToken`
+    #[setter]
+    fn set_content(&mut self, content: String) {
+        self.content = content;
+    }
+
     /// Get the value of the :obj:`rstrip` option
     #[getter]
     fn get_rstrip(&self) -> bool {
@@ -199,6 +211,17 @@ impl PyAddedToken {
     fn get_normalized(&self) -> bool {
         self.get_token().normalized
     }
+    /// Get the value of the :obj:`special` option
+    #[getter]
+    fn get_special(&self) -> bool {
+        self.get_token().special
+    }
+
+    /// Set the value of the :obj:`special` option
+    #[setter]
+    fn set_special(&mut self, special: bool) {
+        self.special = special;
+    }
 
     fn __str__(&self) -> PyResult<&str> {
         Ok(&self.content)
@@ -212,12 +235,13 @@ impl PyAddedToken {
 
         let token = self.get_token();
         Ok(format!(
-            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
+            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={}, special={})",
             self.content,
             bool_to_python(token.rstrip),
             bool_to_python(token.lstrip),
             bool_to_python(token.single_word),
-            bool_to_python(token.normalized)
+            bool_to_python(token.normalized),
+            bool_to_python(token.special)
         ))
     }
 
@@ -639,6 +663,22 @@ impl PyTokenizer {
         self.tokenizer.get_vocab(with_added_tokens)
     }
 
+    /// Get the underlying vocabulary
+    ///
+    /// Returns:
+    ///     :obj:`Dict[int, AddedToken]`: The vocabulary
+    #[pyo3(signature = ())]
+    #[pyo3(text_signature = "(self)")]
+    fn get_added_tokens_decoder(&self) -> BTreeMap<u32, PyAddedToken> {
+        let mut sorted_map = BTreeMap::new();
+
+        for (key, value) in self.tokenizer.get_added_tokens_decoder() {
+            sorted_map.insert(key, value.into());
+        }
+
+        sorted_map
+    }
+
     /// Get the size of the underlying vocabulary
     ///
     /// Args:
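Because the binding collects the entries into a BTreeMap before crossing into Python, the dictionary handed back is ordered by token id. Continuing the earlier sketch:

    ids = list(tokenizer.get_added_tokens_decoder().keys())
    assert ids == sorted(ids)  # BTreeMap iterates in key order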
@@ -1090,7 +1130,7 @@ impl PyTokenizer {
             if let Ok(content) = token.extract::<String>() {
                 Ok(PyAddedToken::from(content, Some(false)).get_token())
             } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                token.is_special_token = false;
+                token.special = false;
                 Ok(token.get_token())
             } else {
                 Err(exceptions::PyTypeError::new_err(
@@ -1127,7 +1167,7 @@ impl PyTokenizer {
             if let Ok(content) = token.extract::<String>() {
                 Ok(tk::tokenizer::AddedToken::from(content, true))
             } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                token.is_special_token = true;
+                token.special = true;
                 Ok(token.get_token())
             } else {
                 Err(exceptions::PyTypeError::new_err(
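The two code paths above mean a plain string gets its special flag from the entry point it goes through; a hedged sketch of the observable behaviour, continuing the earlier session (so the new tokens land on ids 2 and 3):

    tokenizer.add_tokens(["plain"])             # stored with special=False
    tokenizer.add_special_tokens(["reserved"])  # stored with special=True
    decoder = tokenizer.get_added_tokens_decoder()
    assert not decoder[2].special and decoder[3].special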
@@ -226,7 +226,7 @@ impl PyBpeTrainer {
             if let Ok(content) = token.extract::<String>() {
                 Ok(tk::tokenizer::AddedToken::from(content, true))
             } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                token.is_special_token = true;
+                token.special = true;
                 Ok(token.get_token())
             } else {
                 Err(exceptions::PyTypeError::new_err(
@@ -319,7 +319,7 @@ impl PyBpeTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -440,7 +440,7 @@ impl PyWordPieceTrainer {
             if let Ok(content) = token.extract::<String>() {
                 Ok(tk::tokenizer::AddedToken::from(content, true))
             } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                token.is_special_token = true;
+                token.special = true;
                 Ok(token.get_token())
             } else {
                 Err(exceptions::PyTypeError::new_err(
@@ -526,7 +526,7 @@ impl PyWordPieceTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -632,7 +632,7 @@ impl PyWordLevelTrainer {
             if let Ok(content) = token.extract::<String>() {
                 Ok(tk::tokenizer::AddedToken::from(content, true))
             } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                token.is_special_token = true;
+                token.special = true;
                 Ok(token.get_token())
             } else {
                 Err(exceptions::PyTypeError::new_err(
@@ -673,7 +673,7 @@ impl PyWordLevelTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
@@ -778,7 +778,7 @@ impl PyUnigramTrainer {
             if let Ok(content) = token.extract::<String>() {
                 Ok(tk::tokenizer::AddedToken::from(content, true))
             } else if let Ok(mut token) = token.extract::<PyRefMut<PyAddedToken>>() {
-                token.is_special_token = true;
+                token.special = true;
                 Ok(token.get_token())
             } else {
                 Err(exceptions::PyTypeError::new_err(
@@ -846,7 +846,7 @@ impl PyUnigramTrainer {
                         } else if let Ok(mut token) =
                             token.extract::<PyRefMut<PyAddedToken>>()
                         {
-                            token.is_special_token = true;
+                            token.special = true;
                             Ok(token.get_token())
                         } else {
                             Err(exceptions::PyTypeError::new_err(
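All four trainers now coerce their special_tokens the same way; for example (a sketch mirroring the trainer tests updated below):

    from tokenizers import AddedToken
    from tokenizers.trainers import BpeTrainer

    trainer = BpeTrainer(special_tokens=["1", "2"])
    assert trainer.special_tokens == [
        AddedToken("1", special=True),
        AddedToken("2", special=True),
    ]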
@@ -16,10 +16,19 @@ from ..utils import bert_files, data_dir, multiprocessing_with_parallelism, robe
 class TestAddedToken:
     def test_instantiate_with_content_only(self):
         added_token = AddedToken("<mask>")
+        added_token.content = "<MASK>"
+        assert added_token.content == "<MASK>"
         assert type(added_token) == AddedToken
+        added_token.content = added_token.content.lower()
+
+        assert added_token.special == False
+        added_token.special = True
+        assert added_token.special == True
+        added_token.special = False
         assert str(added_token) == "<mask>"
         assert (
-            repr(added_token) == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
+            repr(added_token)
+            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False)'
         )
         assert added_token.rstrip == False
         assert added_token.lstrip == False
@@ -365,6 +374,16 @@ class TestTokenizer:
         vocab = tokenizer.get_vocab(with_added_tokens=False)
         assert vocab == {}
 
+        # Can retrieve added token decoder
+        vocab = tokenizer.get_added_tokens_decoder()
+        assert vocab == {
+            0: AddedToken("my", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            1: AddedToken("name", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            2: AddedToken("is", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            3: AddedToken("john", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+            4: AddedToken("pair", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
+        }
+
     def test_get_vocab_size(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
@@ -34,8 +34,8 @@ class TestBpeTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -91,8 +91,8 @@ class TestWordPieceTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
         assert trainer.limit_alphabet == 13
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
@@ -131,8 +131,8 @@ class TestWordLevelTrainer:
         assert trainer.min_frequency == 12
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1"),
-            AddedToken("2"),
+            AddedToken("1", special=True),
+            AddedToken("2", special=True),
         ]
 
         # Modify these
@@ -272,8 +272,8 @@ class TestUnigram:
         assert trainer.vocab_size == 12345
         assert trainer.show_progress == False
         assert trainer.special_tokens == [
-            AddedToken("1", normalized=False),
-            AddedToken("2", lstrip=True, normalized=False),
+            AddedToken("1", normalized=False, special=True),
+            AddedToken("2", lstrip=True, normalized=False, special=True),
         ]
         assert sorted(trainer.initial_alphabet) == ["a", "b", "c"]
 
@@ -11,7 +11,7 @@ use std::collections::{HashMap, HashSet};
 /// like:
 /// - Whether they should only match single words
 /// - Whether to include any whitespace on its left or right
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
 pub struct AddedToken {
     /// The content of the added token
     pub content: String,
@@ -66,6 +66,12 @@ impl AddedToken {
         self.normalized = normalized;
         self
     }
+    /// Specify whether this token is special, meaning it should be skipped when decoding
+    #[must_use]
+    pub fn special(mut self, special: bool) -> Self {
+        self.special = special;
+        self
+    }
 }
 impl Default for AddedToken {
     fn default() -> Self {
@@ -79,19 +85,12 @@ impl Default for AddedToken {
         }
     }
 }
-// We only want to hash on the content. AddedToken cannot be added multiple times with different
-// options
+// AddedTokens can be updated if value changed
 impl std::hash::Hash for AddedToken {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
         self.content.hash(state);
     }
 }
-impl std::cmp::PartialEq for AddedToken {
-    fn eq(&self, other: &Self) -> bool {
-        self.content == other.content
-    }
-}
-impl std::cmp::Eq for AddedToken {}
 
 type MatchingSet = (AhoCorasick, Vec<u32>);
 
@@ -181,8 +180,8 @@ impl AddedVocabulary {
             split_normalized_trie: (normalized_trie, vec![]),
         }
     }
 
     /// Size of the additional vocabulary
+    #[allow(dead_code)] // Suppress the "method is never used" warning
     pub fn len(&self) -> usize {
         self.added_tokens_map.len()
     }
@@ -192,6 +191,11 @@ impl AddedVocabulary {
         &self.added_tokens_map
     }
 
+    /// Get the additional vocabulary with the AddedTokens
+    pub fn get_added_tokens_decoder(&self) -> &HashMap<u32, AddedToken> {
+        &self.added_tokens_map_r
+    }
+
     /// Get the id matching one of our token if it exists
     pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
         self.added_tokens_map
@@ -244,30 +248,42 @@ impl AddedVocabulary {
         // Then we delegate to `add_tokens`, that will take care of refreshing added tokens too.
         let mut ignored = 0;
         for token in tokens {
-            if token.content.is_empty() {
+            if token.content.is_empty() || self.added_tokens_map_r.values().any(|val| val == token)
+            {
                 ignored += 1;
                 continue;
             }
-            let id = if let Some(id) = self.token_to_id(&token.content, model) {
-                ignored += 1;
-                id
+            // If a token is already part of the vocabulary, we mark it as added
+            let new_id = if let Some(new_id) = self.token_to_id(&token.content, model) {
+                new_id
             } else {
-                let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
-                self.added_tokens_map.insert(token.content.clone(), new_id);
+                self.added_tokens_map.values().cloned().max().map_or(
+                    model.get_vocab_size() as u32,
+                    |max| {
+                        if (max >= model.get_vocab_size() as u32) || model.get_vocab_size() == 0 {
+                            max + 1
+                        } else {
+                            model.get_vocab_size() as u32
+                        }
+                    },
+                )
+            };
+            // Make sure we modify the previous entry
+            self.added_tokens_map
+                .entry(token.content.clone())
+                .and_modify(|old_id| *old_id = new_id)
+                .or_insert_with(|| new_id);
+            // Update the current revert operation
+            self.added_tokens_map_r
+                .entry(new_id)
+                .and_modify(|t| *t = token.clone())
+                .or_insert_with(|| token.clone());
+            // Make sure to remove previous entry (if the token gets a new id)
 
+            // Finally add the token to the classic set if special
             if !self.special_tokens_set.contains(&token.content) {
                 self.added_tokens.push(token.clone());
             }
-
-                new_id
-            };
-
-            // Update the current revert operation
-            self.added_tokens_map_r
-                .entry(id)
-                .and_modify(|t| *t = token.clone())
-                .or_insert_with(|| token.clone());
         }
 
         self.refresh_added_tokens(model, normalizer);
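The net effect of the new id computation and in-place entry updates, seen from Python (a sketch under the assumption of an empty BPE model, as in the tests):

    from tokenizers import Tokenizer, AddedToken
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_tokens(["my", "name"])                   # fresh ids 0 and 1
    tokenizer.add_tokens([AddedToken("my", lstrip=True)])  # same content: id 0 is reused,
                                                           # the stored AddedToken is refreshed
    assert tokenizer.get_added_tokens_decoder()[0].lstrip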
@@ -569,7 +585,9 @@ mod tests {
             ),
             1
         );
-        assert_eq!(vocab.len(), 1);
+        let vocab_len: usize = vocab.len();
+        assert_eq!(vocab_len, 1);
+
         // Does not add the same token multiple times
         assert_eq!(
@@ -585,12 +603,15 @@ mod tests {
         );
         assert_eq!(vocab.len(), 2);
 
-        // Does not add tokens already covered by the model
+        // Also adds tokens already covered by the model
+        let added_token = AddedToken::from("test", false);
         assert_eq!(
-            vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
-            0
+            vocab.add_tokens(&[added_token.clone()], &model, normalizer),
+            1
         );
-        assert_eq!(vocab.len(), 2);
+        assert_eq!(vocab.len(), 3);
 
+        assert_eq!(vocab.get_added_tokens_decoder()[&0], added_token);
     }
 
     #[test]
@@ -626,11 +647,47 @@ mod tests {
         // Can add tokens already covered by the model
         assert_eq!(
             vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
-            0
+            1
         );
-        assert_eq!(vocab.len(), 2); // Did not add a new token, since it exists in the original model
+        assert_eq!(vocab.len(), 3); // New token was added
         assert!(vocab.is_special_token("test"));
-        assert!(!vocab.added_tokens_map.contains_key("test"));
+        assert_eq!(
+            *vocab.get_added_tokens_decoder(),
+            HashMap::from([
+                (0, AddedToken::from("test", true)),
+                (2, AddedToken::from("added_token_1", true)),
+                (3, AddedToken::from("added_token_2", true)),
+            ])
+        );
+        assert!(vocab.added_tokens_map.contains_key("test"));
+        assert!(vocab.added_tokens_map_r.contains_key(&0));
+
+        vocab.add_tokens(
+            &[
+                AddedToken::from("tost", true),
+                AddedToken::from("another_two", false),
+            ],
+            &model,
+            normalizer,
+        );
+        assert_eq!(vocab.len(), 5); // New tokens were added
+        assert_eq!(vocab.get_vocab()["another_two"], 4); // New token was added, but the index is not the length of the vocab
+
+        // Let's add an already added token again
+        assert_eq!(
+            vocab.add_special_tokens(&[AddedToken::from("another_two", true)], &model, normalizer),
+            1
+        );
+        assert_eq!(vocab.len(), 5); // Token was already there
+        assert_eq!(vocab.get_vocab()["another_two"], 4); // Token idx not changed
+
+        // Just checking that we can set the content of the string in rust
+        let mut token: AddedToken = AddedToken::from("Hey", false);
+        token.content = "hey".to_string();
+        assert_eq!(token.content, "hey"); // Content was updated
+
+        token.special = true;
+        assert!(token.special); // Special flag was updated
     }
 
     #[test]
@@ -766,6 +823,8 @@ mod tests {
         let mut vocab = AddedVocabulary::new();
         let normalizer = Lowercase;
 
+        assert_eq!(vocab.len(), 0);
+
         vocab.add_tokens(
             &[AddedToken::from("<mask>", false).single_word(true)],
             &model,
@@ -659,13 +659,19 @@ where
         final_vocab
     }
 
+    /// Get the added tokens decoder
+    pub fn get_added_tokens_decoder(&self) -> HashMap<u32, AddedToken> {
+        self.added_vocabulary.get_added_tokens_decoder().clone()
+    }
+
     /// Get the size of the vocabulary
     pub fn get_vocab_size(&self, with_added_tokens: bool) -> usize {
-        self.model.get_vocab_size()
-            + if with_added_tokens {
-                self.added_vocabulary.len()
+        // TODO ArthurZ THIS IS WRONG! We need to measure the length of the `set` because
+        // now some tokens can be both in the added_tokens_encoder and in the vocab
+        if with_added_tokens {
+            self.get_vocab(true).len()
         } else {
-                0
+            self.model.get_vocab_size()
         }
     }
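With this change the reported size is, by construction, consistent with the merged vocabulary even when added tokens overlap the model's own entries; a hedged sketch of the resulting invariant, for any tokenizer built from this revision:

    assert tokenizer.get_vocab_size(with_added_tokens=True) == len(
        tokenizer.get_vocab(with_added_tokens=True)
    )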