Python - Improve AddedToken interface

Anthony MOI
2020-06-19 17:53:46 -04:00
parent a14cd7b219
commit c02d4e2202
5 changed files with 125 additions and 87 deletions
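
In short, `AddedToken.__new__` no longer takes an `is_special_token` positional argument: the content is the only positional parameter, the behaviour flags are optional keyword arguments, and whether a token counts as special is decided later by the method it is handed to (`add_tokens` vs `add_special_tokens`). A minimal usage sketch based on the updated tests below (import path assumed to be the usual `tokenizers` package built from this commit):

    from tokenizers import AddedToken

    # Content is the only positional argument; every flag is a keyword.
    token = AddedToken("<mask>", rstrip=True)
    assert token.rstrip and not token.lstrip and not token.single_word
    # For a regular (non-special) token, `normalized` now defaults to True.
    assert token.normalized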


@@ -22,94 +22,133 @@ use tk::tokenizer::{
#[pyclass(dict, module = "tokenizers")]
pub struct AddedToken {
pub token: tk::tokenizer::AddedToken,
pub content: String,
pub is_special_token: bool,
pub single_word: Option<bool>,
pub lstrip: Option<bool>,
pub rstrip: Option<bool>,
pub normalized: Option<bool>,
}
impl AddedToken {
pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
Self {
content: content.into(),
is_special_token: is_special_token.unwrap_or(false),
single_word: None,
lstrip: None,
rstrip: None,
normalized: None,
}
}
pub fn get_token(&self) -> tk::tokenizer::AddedToken {
let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
if let Some(sw) = self.single_word {
token = token.single_word(sw);
}
if let Some(ls) = self.lstrip {
token = token.lstrip(ls);
}
if let Some(rs) = self.rstrip {
token = token.rstrip(rs);
}
if let Some(n) = self.normalized {
token = token.normalized(n);
}
token
}
pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
let dict = PyDict::new(py);
let token = self.get_token();
dict.set_item("content", token.content)?;
dict.set_item("single_word", token.single_word)?;
dict.set_item("lstrip", token.lstrip)?;
dict.set_item("rstrip", token.rstrip)?;
dict.set_item("normalized", token.normalized)?;
Ok(dict)
}
}
#[pymethods]
impl AddedToken {
#[new]
#[args(kwargs = "**")]
fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
let mut token = AddedToken::from(content.unwrap_or(""), None);
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"single_word" => token = token.single_word(value.extract()?),
"lstrip" => token = token.lstrip(value.extract()?),
"rstrip" => token = token.rstrip(value.extract()?),
"normalized" => token = token.normalized(value.extract()?),
"single_word" => token.single_word = Some(value.extract()?),
"lstrip" => token.lstrip = Some(value.extract()?),
"rstrip" => token.rstrip = Some(value.extract()?),
"normalized" => token.normalized = Some(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
Ok(AddedToken { token })
Ok(token)
}
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.token).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to pickle AddedToken: {}",
e.to_string()
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
self.as_pydict(py)
}
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to unpickle AddedToken: {}",
e.to_string()
))
})?;
match state.extract::<&PyDict>(py) {
Ok(state) => {
for (key, value) in state {
let key: &str = key.extract()?;
match key {
"single_word" => self.single_word = Some(value.extract()?),
"lstrip" => self.lstrip = Some(value.extract()?),
"rstrip" => self.rstrip = Some(value.extract()?),
"normalized" => self.normalized = Some(value.extract()?),
_ => {}
}
}
Ok(())
}
Err(e) => Err(e),
}
}
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
// We don't really care about the values of `content` & `is_special_token` here because
// they will get overridden by `__setstate__`
let content: PyObject = "".into_py(py);
let is_special_token: PyObject = false.into_py(py);
let args = PyTuple::new(py, vec![content, is_special_token]);
Ok(args)
}
#[getter]
fn get_content(&self) -> &str {
&self.token.content
&self.content
}
#[getter]
fn get_rstrip(&self) -> bool {
self.token.rstrip
self.get_token().rstrip
}
#[getter]
fn get_lstrip(&self) -> bool {
self.token.lstrip
self.get_token().lstrip
}
#[getter]
fn get_single_word(&self) -> bool {
self.token.single_word
self.get_token().single_word
}
#[getter]
fn get_normalized(&self) -> bool {
self.token.normalized
self.get_token().normalized
}
}
#[pyproto]
impl PyObjectProtocol for AddedToken {
fn __str__(&'p self) -> PyResult<&'p str> {
Ok(&self.token.content)
Ok(&self.content)
}
fn __repr__(&self) -> PyResult<String> {
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
false => "False",
};
let token = self.get_token();
Ok(format!(
"AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
self.token.content,
bool_to_python(self.token.rstrip),
bool_to_python(self.token.lstrip),
bool_to_python(self.token.single_word),
bool_to_python(self.token.normalized)
self.content,
bool_to_python(token.rstrip),
bool_to_python(token.lstrip),
bool_to_python(token.single_word),
bool_to_python(token.normalized)
))
}
}
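
The pickling support above moves from JSON-serialized bytes to a plain dict of the token's fields: `__getstate__` now returns the dict built by `as_pydict`, and `__setstate__` reads the optional flags back out of it. A small round-trip sketch, mirroring the `pickle.loads(pickle.dumps(...))` check in the updated tests; the flag assertions at the end follow from the new `__setstate__`, which restores each flag from the dict, rather than from the test itself:

    import pickle
    from tokenizers import AddedToken

    token = AddedToken("<mask>", lstrip=True, normalized=False)
    restored = pickle.loads(pickle.dumps(token))
    assert isinstance(restored, AddedToken)
    # The optional flags travel through the dict-based state.
    assert restored.lstrip and not restored.normalized
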
@@ -583,9 +623,10 @@ impl Tokenizer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, false))
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
Ok(AddedToken::from(content, Some(false)).get_token())
} else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
token.is_special_token = false;
Ok(token.get_token())
} else {
Err(exceptions::Exception::py_err(
"Input must be a List[Union[str, AddedToken]]",
@@ -603,8 +644,9 @@ impl Tokenizer {
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken::from(content, true))
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
} else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::Exception::py_err(
"Input must be a List[Union[str, AddedToken]]",


@@ -36,12 +36,12 @@ impl BpeTrainer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken {
content,
..Default::default()
})
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
Ok(AddedToken::from(content, Some(true)).get_token())
} else if let Ok(mut token) =
token.extract::<PyRefMut<AddedToken>>()
{
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::Exception::py_err(
"special_tokens must be a List[Union[str, AddedToken]]",
@@ -105,12 +105,12 @@ impl WordPieceTrainer {
.into_iter()
.map(|token| {
if let Ok(content) = token.extract::<String>() {
Ok(tk::tokenizer::AddedToken {
content,
..Default::default()
})
} else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
Ok(token.token.clone())
Ok(AddedToken::from(content, Some(true)).get_token())
} else if let Ok(mut token) =
token.extract::<PyRefMut<AddedToken>>()
{
token.is_special_token = true;
Ok(token.get_token())
} else {
Err(exceptions::Exception::py_err(
"special_tokens must be a List[Union[str, AddedToken]]",


@@ -12,49 +12,46 @@ from tokenizers.implementations import BertWordPieceTokenizer
class TestAddedToken:
def test_instantiate_with_content_only(self):
added_token = AddedToken("<mask>", True)
added_token = AddedToken("<mask>")
assert type(added_token) == AddedToken
assert str(added_token) == "<mask>"
assert (
repr(added_token)
== 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False)'
== 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
)
assert added_token.rstrip == False
assert added_token.lstrip == False
assert added_token.single_word == False
assert added_token.normalized == False
assert added_token.normalized == True
assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)
def test_can_set_rstrip(self):
added_token = AddedToken("<mask>", True, rstrip=True)
added_token = AddedToken("<mask>", rstrip=True)
assert added_token.rstrip == True
assert added_token.lstrip == False
assert added_token.single_word == False
assert added_token.normalized == True
def test_can_set_lstrip(self):
added_token = AddedToken("<mask>", True, lstrip=True)
added_token = AddedToken("<mask>", lstrip=True)
assert added_token.rstrip == False
assert added_token.lstrip == True
assert added_token.single_word == False
assert added_token.normalized == True
def test_can_set_single_word(self):
added_token = AddedToken("<mask>", True, single_word=True)
added_token = AddedToken("<mask>", single_word=True)
assert added_token.rstrip == False
assert added_token.lstrip == False
assert added_token.single_word == True
assert added_token.normalized == True
def test_can_set_normalized(self):
added_token = AddedToken("<mask>", True, normalized=True)
added_token = AddedToken("<mask>", normalized=False)
assert added_token.rstrip == False
assert added_token.lstrip == False
assert added_token.single_word == False
assert added_token.normalized == True
def test_second_argument_defines_normalized(self):
added_token = AddedToken("<mask>", True)
assert added_token.normalized == False
added_token = AddedToken("<mask>", False)
assert added_token.normalized == True
class TestTokenizer:
@@ -91,10 +88,12 @@ class TestTokenizer:
added = tokenizer.add_tokens(["my", "name", "is", "john"])
assert added == 4
added = tokenizer.add_tokens(
[AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
)
tokens = [AddedToken("the"), AddedToken("quick", normalized=False), AddedToken()]
assert tokens[0].normalized == True
added = tokenizer.add_tokens(tokens)
assert added == 2
assert tokens[0].normalized == True
assert tokens[1].normalized == False
def test_add_special_tokens(self):
tokenizer = Tokenizer(BPE())
@@ -104,10 +103,12 @@ class TestTokenizer:
assert added == 4
# Can add special tokens as `AddedToken`
added = tokenizer.add_special_tokens(
[AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
)
tokens = [AddedToken("the"), AddedToken("quick", normalized=True), AddedToken()]
assert tokens[0].normalized == True
added = tokenizer.add_special_tokens(tokens)
assert added == 2
assert tokens[0].normalized == False
assert tokens[1].normalized == True
def test_encode(self):
tokenizer = Tokenizer(BPE())


@@ -201,8 +201,7 @@ class AddedToken:
def __new__(
cls,
content: str,
is_special_token: bool,
content: str = "",
single_word: bool = False,
lstrip: bool = False,
rstrip: bool = False,
@@ -214,10 +213,6 @@
content: str:
The content of the token
is_special_token: bool:
Whether this token is a special token. This has an impact on the default value for
`normalized` which is False for special tokens, but True for others.
single_word: bool
Whether this token should only match against single words. If True,
this token will never match inside of a word. For example the token `ing` would


@@ -219,7 +219,7 @@ impl AddedVocabulary {
normalizer: Option<&dyn Normalizer>,
) -> usize {
for token in tokens {
if !self.special_tokens_set.contains(&token.content) {
if !token.content.is_empty() && !self.special_tokens_set.contains(&token.content) {
self.special_tokens.push(token.to_owned());
self.special_tokens_set.insert(token.content.clone());
}