mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Python - Improve AddedToken interface
@@ -22,94 +22,133 @@ use tk::tokenizer::{

 #[pyclass(dict, module = "tokenizers")]
 pub struct AddedToken {
-    pub token: tk::tokenizer::AddedToken,
+    pub content: String,
+    pub is_special_token: bool,
+    pub single_word: Option<bool>,
+    pub lstrip: Option<bool>,
+    pub rstrip: Option<bool>,
+    pub normalized: Option<bool>,
 }
+impl AddedToken {
+    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+        Self {
+            content: content.into(),
+            is_special_token: is_special_token.unwrap_or(false),
+            single_word: None,
+            lstrip: None,
+            rstrip: None,
+            normalized: None,
+        }
+    }
+
+    pub fn get_token(&self) -> tk::tokenizer::AddedToken {
+        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+
+        if let Some(sw) = self.single_word {
+            token = token.single_word(sw);
+        }
+        if let Some(ls) = self.lstrip {
+            token = token.lstrip(ls);
+        }
+        if let Some(rs) = self.rstrip {
+            token = token.rstrip(rs);
+        }
+        if let Some(n) = self.normalized {
+            token = token.normalized(n);
+        }
+
+        token
+    }
+
+    pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        let dict = PyDict::new(py);
+        let token = self.get_token();
+
+        dict.set_item("content", token.content)?;
+        dict.set_item("single_word", token.single_word)?;
+        dict.set_item("lstrip", token.lstrip)?;
+        dict.set_item("rstrip", token.rstrip)?;
+        dict.set_item("normalized", token.normalized)?;
+
+        Ok(dict)
+    }
+}
+
 #[pymethods]
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
+    fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = AddedToken::from(content.unwrap_or(""), None);

         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
                 let key: &str = key.extract()?;
                 match key {
-                    "single_word" => token = token.single_word(value.extract()?),
-                    "lstrip" => token = token.lstrip(value.extract()?),
-                    "rstrip" => token = token.rstrip(value.extract()?),
-                    "normalized" => token = token.normalized(value.extract()?),
+                    "single_word" => token.single_word = Some(value.extract()?),
+                    "lstrip" => token.lstrip = Some(value.extract()?),
+                    "rstrip" => token.rstrip = Some(value.extract()?),
+                    "normalized" => token.normalized = Some(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
         }

-        Ok(AddedToken { token })
+        Ok(token)
     }

-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = serde_json::to_string(&self.token).map_err(|e| {
-            exceptions::Exception::py_err(format!(
-                "Error while attempting to pickle AddedToken: {}",
-                e.to_string()
-            ))
-        })?;
-        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        self.as_pydict(py)
     }

     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
-        match state.extract::<&PyBytes>(py) {
-            Ok(s) => {
-                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
-                        "Error while attempting to unpickle AddedToken: {}",
-                        e.to_string()
-                    ))
-                })?;
+        match state.extract::<&PyDict>(py) {
+            Ok(state) => {
+                for (key, value) in state {
+                    let key: &str = key.extract()?;
+                    match key {
+                        "single_word" => self.single_word = Some(value.extract()?),
+                        "lstrip" => self.lstrip = Some(value.extract()?),
+                        "rstrip" => self.rstrip = Some(value.extract()?),
+                        "normalized" => self.normalized = Some(value.extract()?),
+                        _ => {}
+                    }
+                }
                 Ok(())
             }
             Err(e) => Err(e),
         }
     }

-    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
-        // We don't really care about the values of `content` & `is_special_token` here because
-        // they will get overriden by `__setstate__`
-        let content: PyObject = "".into_py(py);
-        let is_special_token: PyObject = false.into_py(py);
-        let args = PyTuple::new(py, vec![content, is_special_token]);
-        Ok(args)
-    }
-
     #[getter]
     fn get_content(&self) -> &str {
-        &self.token.content
+        &self.content
     }

     #[getter]
     fn get_rstrip(&self) -> bool {
-        self.token.rstrip
+        self.get_token().rstrip
     }

     #[getter]
     fn get_lstrip(&self) -> bool {
-        self.token.lstrip
+        self.get_token().lstrip
     }

     #[getter]
     fn get_single_word(&self) -> bool {
-        self.token.single_word
+        self.get_token().single_word
     }

     #[getter]
     fn get_normalized(&self) -> bool {
-        self.token.normalized
+        self.get_token().normalized
     }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
     fn __str__(&'p self) -> PyResult<&'p str> {
-        Ok(&self.token.content)
+        Ok(&self.content)
     }

     fn __repr__(&self) -> PyResult<String> {
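The hunk above replaces the wrapped `tk::tokenizer::AddedToken` with plain option fields that are only resolved into a core token by `get_token()`. A minimal Python sketch of the resulting constructor, assuming the built bindings are importable as `tokenizers`; the expected repr is the one asserted in the test changes further down:

from tokenizers import AddedToken

# content is the only positional argument; the options are keyword-only kwargs
token = AddedToken("<mask>", lstrip=True)
assert str(token) == "<mask>"

# for a plain (non-special) token, unset options resolve to:
# rstrip=False, lstrip=False, single_word=False, normalized=True
assert repr(AddedToken("<mask>")) == (
    'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
)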
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
             false => "False",
         };

+        let token = self.get_token();
         Ok(format!(
             "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
-            self.token.content,
-            bool_to_python(self.token.rstrip),
-            bool_to_python(self.token.lstrip),
-            bool_to_python(self.token.single_word),
-            bool_to_python(self.token.normalized)
+            self.content,
+            bool_to_python(token.rstrip),
+            bool_to_python(token.lstrip),
+            bool_to_python(token.single_word),
+            bool_to_python(token.normalized)
         ))
     }
 }
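Because `__getstate__`/`__setstate__` now round-trip a plain options dict instead of serde-serialized bytes, pickling still works from Python. A small sketch mirroring the test below, with the imports assumed:

import pickle
from tokenizers import AddedToken

token = AddedToken("<mask>", rstrip=True)
restored = pickle.loads(pickle.dumps(token))
assert isinstance(restored, AddedToken)  # the test only asserts the type survives the round trip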
@@ -583,9 +623,10 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken::from(content, false))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(false)).get_token())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = false;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
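With this hunk, `Tokenizer.add_tokens` accepts either plain strings or `AddedToken` instances and marks both as non-special. A short sketch mirroring the tests below, imports assumed:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
added = tokenizer.add_tokens(["my", AddedToken("quick", normalized=False)])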
@@ -603,8 +644,9 @@ impl Tokenizer {
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken::from(content, true))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(true)).get_token())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
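`Tokenizer.add_special_tokens` does the same but sets `is_special_token` on the passed `AddedToken`, which also flips the default for `normalized` to False. A self-contained sketch of that behaviour, grounded in the test assertions below:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
special = AddedToken("<cls>")
assert special.normalized == True    # plain token: normalized defaults to True
tokenizer.add_special_tokens(["<sep>", special])
assert special.normalized == False   # now marked special: the default flips to False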
@@ -36,12 +36,12 @@ impl BpeTrainer {
                             .into_iter()
                             .map(|token| {
                                 if let Ok(content) = token.extract::<String>() {
-                                    Ok(tk::tokenizer::AddedToken {
-                                        content,
-                                        ..Default::default()
-                                    })
-                                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                                    Ok(token.token.clone())
+                                    Ok(AddedToken::from(content, Some(true)).get_token())
+                                } else if let Ok(mut token) =
+                                    token.extract::<PyRefMut<AddedToken>>()
+                                {
+                                    token.is_special_token = true;
+                                    Ok(token.get_token())
                                 } else {
                                     Err(exceptions::Exception::py_err(
                                         "special_tokens must be a List[Union[str, AddedToken]]",
@@ -105,12 +105,12 @@ impl WordPieceTrainer {
                             .into_iter()
                             .map(|token| {
                                 if let Ok(content) = token.extract::<String>() {
-                                    Ok(tk::tokenizer::AddedToken {
-                                        content,
-                                        ..Default::default()
-                                    })
-                                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                                    Ok(token.token.clone())
+                                    Ok(AddedToken::from(content, Some(true)).get_token())
+                                } else if let Ok(mut token) =
+                                    token.extract::<PyRefMut<AddedToken>>()
+                                {
+                                    token.is_special_token = true;
+                                    Ok(token.get_token())
                                 } else {
                                     Err(exceptions::Exception::py_err(
                                         "special_tokens must be a List[Union[str, AddedToken]]",
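The trainers get the same treatment: their `special_tokens` list may mix strings and `AddedToken`s, and every entry is marked special. A sketch under the assumption that `BpeTrainer` is constructed here with only the `special_tokens` keyword, other options left at their defaults:

from tokenizers import AddedToken
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(special_tokens=["<unk>", AddedToken("<mask>", lstrip=True)])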
@@ -12,49 +12,46 @@ from tokenizers.implementations import BertWordPieceTokenizer

 class TestAddedToken:
     def test_instantiate_with_content_only(self):
-        added_token = AddedToken("<mask>", True)
+        added_token = AddedToken("<mask>")
         assert type(added_token) == AddedToken
         assert str(added_token) == "<mask>"
         assert (
             repr(added_token)
-            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False)'
+            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
         )
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == False
+        assert added_token.normalized == True
         assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)

     def test_can_set_rstrip(self):
-        added_token = AddedToken("<mask>", True, rstrip=True)
+        added_token = AddedToken("<mask>", rstrip=True)
         assert added_token.rstrip == True
         assert added_token.lstrip == False
         assert added_token.single_word == False
+        assert added_token.normalized == True

     def test_can_set_lstrip(self):
-        added_token = AddedToken("<mask>", True, lstrip=True)
+        added_token = AddedToken("<mask>", lstrip=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == True
         assert added_token.single_word == False
+        assert added_token.normalized == True

     def test_can_set_single_world(self):
-        added_token = AddedToken("<mask>", True, single_word=True)
+        added_token = AddedToken("<mask>", single_word=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == True
+        assert added_token.normalized == True

     def test_can_set_normalized(self):
-        added_token = AddedToken("<mask>", True, normalized=True)
+        added_token = AddedToken("<mask>", normalized=False)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == True
-
-    def test_second_argument_defines_normalized(self):
-        added_token = AddedToken("<mask>", True)
-        assert added_token.normalized == False
-        added_token = AddedToken("<mask>", False)
-        assert added_token.normalized == True
+        assert added_token.normalized == False


 class TestTokenizer:
@@ -91,10 +88,12 @@ class TestTokenizer:
         added = tokenizer.add_tokens(["my", "name", "is", "john"])
         assert added == 4

-        added = tokenizer.add_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=False), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == True
+        assert tokens[1].normalized == False

     def test_add_special_tokens(self):
         tokenizer = Tokenizer(BPE())
@@ -104,10 +103,12 @@ class TestTokenizer:
         assert added == 4

         # Can add special tokens as `AddedToken`
-        added = tokenizer.add_special_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=True), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_special_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == False
+        assert tokens[1].normalized == True

     def test_encode(self):
         tokenizer = Tokenizer(BPE())
@@ -201,8 +201,7 @@ class AddedToken:

     def __new__(
         cls,
-        content: str,
-        is_special_token: bool,
+        content: str = "",
         single_word: bool = False,
         lstrip: bool = False,
         rstrip: bool = False,
@@ -214,10 +213,6 @@ class AddedToken:
            content: str:
                The content of the token

-           is_special_token: bool:
-               Whether this token is a special token. This has an impact on the default value for
-               `normalized` which is False for special tokens, but True for others.
-
            single_word: bool
                Whether this token should only match against single words. If True,
                this token will never match inside of a word. For example the token `ing` would
@@ -219,7 +219,7 @@ impl AddedVocabulary {
         normalizer: Option<&dyn Normalizer>,
     ) -> usize {
         for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
+            if !token.content.is_empty() && !self.special_tokens_set.contains(&token.content) {
                 self.special_tokens.push(token.to_owned());
                 self.special_tokens_set.insert(token.content.clone());
             }
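The added `is_empty()` guard is what makes a default-constructed `AddedToken()` (content `""`) a no-op when registered, which the updated tests above rely on. A minimal sketch of that behaviour, imports assumed:

from tokenizers import Tokenizer, AddedToken
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
added = tokenizer.add_special_tokens([AddedToken("the"), AddedToken("quick"), AddedToken()])
assert added == 2  # the empty-content token is skipped, matching the test above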