Python - Improve AddedToken interface

Anthony MOI
2020-06-19 17:53:46 -04:00
parent a14cd7b219
commit c02d4e2202
5 changed files with 125 additions and 87 deletions

View File

@@ -22,94 +22,133 @@ use tk::tokenizer::{
 #[pyclass(dict, module = "tokenizers")]
 pub struct AddedToken {
-    pub token: tk::tokenizer::AddedToken,
+    pub content: String,
+    pub is_special_token: bool,
+    pub single_word: Option<bool>,
+    pub lstrip: Option<bool>,
+    pub rstrip: Option<bool>,
+    pub normalized: Option<bool>,
 }
+impl AddedToken {
+    pub fn from<S: Into<String>>(content: S, is_special_token: Option<bool>) -> Self {
+        Self {
+            content: content.into(),
+            is_special_token: is_special_token.unwrap_or(false),
+            single_word: None,
+            lstrip: None,
+            rstrip: None,
+            normalized: None,
+        }
+    }
+    pub fn get_token(&self) -> tk::tokenizer::AddedToken {
+        let mut token = tk::AddedToken::from(&self.content, self.is_special_token);
+        if let Some(sw) = self.single_word {
+            token = token.single_word(sw);
+        }
+        if let Some(ls) = self.lstrip {
+            token = token.lstrip(ls);
+        }
+        if let Some(rs) = self.rstrip {
+            token = token.rstrip(rs);
+        }
+        if let Some(n) = self.normalized {
+            token = token.normalized(n);
+        }
+        token
+    }
+    pub fn as_pydict<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        let dict = PyDict::new(py);
+        let token = self.get_token();
+        dict.set_item("content", token.content)?;
+        dict.set_item("single_word", token.single_word)?;
+        dict.set_item("lstrip", token.lstrip)?;
+        dict.set_item("rstrip", token.rstrip)?;
+        dict.set_item("normalized", token.normalized)?;
+        Ok(dict)
+    }
+}
 #[pymethods]
 impl AddedToken {
     #[new]
     #[args(kwargs = "**")]
-    fn new(content: &str, is_special_token: bool, kwargs: Option<&PyDict>) -> PyResult<Self> {
-        let mut token = tk::tokenizer::AddedToken::from(content, is_special_token);
+    fn new(content: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<Self> {
+        let mut token = AddedToken::from(content.unwrap_or(""), None);
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
                 let key: &str = key.extract()?;
                 match key {
-                    "single_word" => token = token.single_word(value.extract()?),
-                    "lstrip" => token = token.lstrip(value.extract()?),
-                    "rstrip" => token = token.rstrip(value.extract()?),
-                    "normalized" => token = token.normalized(value.extract()?),
+                    "single_word" => token.single_word = Some(value.extract()?),
+                    "lstrip" => token.lstrip = Some(value.extract()?),
+                    "rstrip" => token.rstrip = Some(value.extract()?),
+                    "normalized" => token.normalized = Some(value.extract()?),
                     _ => println!("Ignored unknown kwarg option {}", key),
                 }
             }
         }
-        Ok(AddedToken { token })
+        Ok(token)
     }
-    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = serde_json::to_string(&self.token).map_err(|e| {
-            exceptions::Exception::py_err(format!(
-                "Error while attempting to pickle AddedToken: {}",
-                e.to_string()
-            ))
-        })?;
-        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<&'py PyDict> {
+        self.as_pydict(py)
     }
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
-        match state.extract::<&PyBytes>(py) {
-            Ok(s) => {
-                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
-                        "Error while attempting to unpickle AddedToken: {}",
-                        e.to_string()
-                    ))
-                })?;
+        match state.extract::<&PyDict>(py) {
+            Ok(state) => {
+                for (key, value) in state {
+                    let key: &str = key.extract()?;
+                    match key {
+                        "single_word" => self.single_word = Some(value.extract()?),
+                        "lstrip" => self.lstrip = Some(value.extract()?),
+                        "rstrip" => self.rstrip = Some(value.extract()?),
+                        "normalized" => self.normalized = Some(value.extract()?),
+                        _ => {}
+                    }
+                }
                 Ok(())
             }
             Err(e) => Err(e),
         }
     }
-    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
-        // We don't really care about the values of `content` & `is_special_token` here because
-        // they will get overriden by `__setstate__`
-        let content: PyObject = "".into_py(py);
-        let is_special_token: PyObject = false.into_py(py);
-        let args = PyTuple::new(py, vec![content, is_special_token]);
-        Ok(args)
-    }
     #[getter]
     fn get_content(&self) -> &str {
-        &self.token.content
+        &self.content
     }
     #[getter]
     fn get_rstrip(&self) -> bool {
-        self.token.rstrip
+        self.get_token().rstrip
     }
     #[getter]
     fn get_lstrip(&self) -> bool {
-        self.token.lstrip
+        self.get_token().lstrip
     }
     #[getter]
     fn get_single_word(&self) -> bool {
-        self.token.single_word
+        self.get_token().single_word
     }
     #[getter]
     fn get_normalized(&self) -> bool {
-        self.token.normalized
+        self.get_token().normalized
     }
 }
 #[pyproto]
 impl PyObjectProtocol for AddedToken {
     fn __str__(&'p self) -> PyResult<&'p str> {
-        Ok(&self.token.content)
+        Ok(&self.content)
     }
     fn __repr__(&self) -> PyResult<String> {
@@ -118,13 +157,14 @@ impl PyObjectProtocol for AddedToken {
             false => "False",
         };
+        let token = self.get_token();
         Ok(format!(
             "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={}, normalized={})",
-            self.token.content,
-            bool_to_python(self.token.rstrip),
-            bool_to_python(self.token.lstrip),
-            bool_to_python(self.token.single_word),
-            bool_to_python(self.token.normalized)
+            self.content,
+            bool_to_python(token.rstrip),
+            bool_to_python(token.lstrip),
+            bool_to_python(token.single_word),
+            bool_to_python(token.normalized)
         ))
     }
 }
@@ -583,9 +623,10 @@ impl Tokenizer {
             .into_iter()
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
-                    Ok(tk::tokenizer::AddedToken::from(content, false))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                    Ok(AddedToken::from(content, Some(false)).get_token())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = false;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",
@@ -603,8 +644,9 @@ impl Tokenizer {
             .map(|token| {
                 if let Ok(content) = token.extract::<String>() {
                     Ok(tk::tokenizer::AddedToken::from(content, true))
-                } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                    Ok(token.token.clone())
+                } else if let Ok(mut token) = token.extract::<PyRefMut<AddedToken>>() {
+                    token.is_special_token = true;
+                    Ok(token.get_token())
                 } else {
                     Err(exceptions::Exception::py_err(
                         "Input must be a List[Union[str, AddedToken]]",

View File

@@ -36,12 +36,12 @@ impl BpeTrainer {
                     .into_iter()
                     .map(|token| {
                         if let Ok(content) = token.extract::<String>() {
-                            Ok(tk::tokenizer::AddedToken {
-                                content,
-                                ..Default::default()
-                            })
-                        } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                            Ok(token.token.clone())
+                            Ok(AddedToken::from(content, Some(true)).get_token())
+                        } else if let Ok(mut token) =
+                            token.extract::<PyRefMut<AddedToken>>()
+                        {
+                            token.is_special_token = true;
+                            Ok(token.get_token())
                         } else {
                             Err(exceptions::Exception::py_err(
                                 "special_tokens must be a List[Union[str, AddedToken]]",
@@ -105,12 +105,12 @@ impl WordPieceTrainer {
                     .into_iter()
                     .map(|token| {
                         if let Ok(content) = token.extract::<String>() {
-                            Ok(tk::tokenizer::AddedToken {
-                                content,
-                                ..Default::default()
-                            })
-                        } else if let Ok(token) = token.extract::<PyRef<AddedToken>>() {
-                            Ok(token.token.clone())
+                            Ok(AddedToken::from(content, Some(true)).get_token())
+                        } else if let Ok(mut token) =
+                            token.extract::<PyRefMut<AddedToken>>()
+                        {
+                            token.is_special_token = true;
+                            Ok(token.get_token())
                         } else {
                             Err(exceptions::Exception::py_err(
                                 "special_tokens must be a List[Union[str, AddedToken]]",

View File

@@ -12,49 +12,46 @@ from tokenizers.implementations import BertWordPieceTokenizer
 class TestAddedToken:
     def test_instantiate_with_content_only(self):
-        added_token = AddedToken("<mask>", True)
+        added_token = AddedToken("<mask>")
         assert type(added_token) == AddedToken
         assert str(added_token) == "<mask>"
         assert (
             repr(added_token)
-            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False)'
+            == 'AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=True)'
         )
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == False
+        assert added_token.normalized == True
         assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)
     def test_can_set_rstrip(self):
-        added_token = AddedToken("<mask>", True, rstrip=True)
+        added_token = AddedToken("<mask>", rstrip=True)
         assert added_token.rstrip == True
         assert added_token.lstrip == False
         assert added_token.single_word == False
+        assert added_token.normalized == True
     def test_can_set_lstrip(self):
-        added_token = AddedToken("<mask>", True, lstrip=True)
+        added_token = AddedToken("<mask>", lstrip=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == True
         assert added_token.single_word == False
+        assert added_token.normalized == True
     def test_can_set_single_world(self):
-        added_token = AddedToken("<mask>", True, single_word=True)
+        added_token = AddedToken("<mask>", single_word=True)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == True
+        assert added_token.normalized == True
     def test_can_set_normalized(self):
-        added_token = AddedToken("<mask>", True, normalized=True)
+        added_token = AddedToken("<mask>", normalized=False)
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
-        assert added_token.normalized == True
-    def test_second_argument_defines_normalized(self):
-        added_token = AddedToken("<mask>", True)
-        assert added_token.normalized == False
-        added_token = AddedToken("<mask>", False)
-        assert added_token.normalized == True
+        assert added_token.normalized == False
 class TestTokenizer:
@@ -91,10 +88,12 @@ class TestTokenizer:
         added = tokenizer.add_tokens(["my", "name", "is", "john"])
         assert added == 4
-        added = tokenizer.add_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=False), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == True
+        assert tokens[1].normalized == False
     def test_add_special_tokens(self):
         tokenizer = Tokenizer(BPE())
@@ -104,10 +103,12 @@ class TestTokenizer:
         assert added == 4
         # Can add special tokens as `AddedToken`
-        added = tokenizer.add_special_tokens(
-            [AddedToken("the", False), AddedToken("quick", False, rstrip=True)]
-        )
+        tokens = [AddedToken("the"), AddedToken("quick", normalized=True), AddedToken()]
+        assert tokens[0].normalized == True
+        added = tokenizer.add_special_tokens(tokens)
         assert added == 2
+        assert tokens[0].normalized == False
+        assert tokens[1].normalized == True
     def test_encode(self):
         tokenizer = Tokenizer(BPE())
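
The `added == 2` assertions above also rely on the default-constructed `AddedToken()` having empty content: such a token is simply not counted, matching the guard added to `added_vocabulary.rs` at the end of this commit. A condensed sketch of that behaviour:

from tokenizers import AddedToken, Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())

# The trailing AddedToken() has empty content and is skipped.
tokens = [AddedToken("the"), AddedToken("quick"), AddedToken()]
assert tokenizer.add_tokens(tokens) == 2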

View File

@@ -201,8 +201,7 @@ class AddedToken:
     def __new__(
         cls,
-        content: str,
-        is_special_token: bool,
+        content: str = "",
         single_word: bool = False,
         lstrip: bool = False,
         rstrip: bool = False,
@@ -214,10 +213,6 @@ class AddedToken:
             content: str:
                 The content of the token
-            is_special_token: bool:
-                Whether this token is a special token. This has an impact on the default value for
-                `normalized` which is False for special tokens, but True for others.
-
             single_word: bool
                 Whether this token should only match against single words. If True,
                 this token will never match inside of a word. For example the token `ing` would
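
For callers, the stub change means the old second positional argument is gone: the special/non-special decision is deferred to `add_tokens` / `add_special_tokens`, and only `content` (now defaulting to an empty string) plus keyword options remain. A small before/after sketch:

from tokenizers import AddedToken

# Before this change:  AddedToken("<mask>", True, lstrip=True)
# After this change, keyword options only:
mask = AddedToken("<mask>", lstrip=True)
empty = AddedToken()  # valid now that `content` defaults to ""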

View File

@@ -219,7 +219,7 @@ impl AddedVocabulary {
         normalizer: Option<&dyn Normalizer>,
     ) -> usize {
         for token in tokens {
-            if !self.special_tokens_set.contains(&token.content) {
+            if !token.content.is_empty() && !self.special_tokens_set.contains(&token.content) {
                 self.special_tokens.push(token.to_owned());
                 self.special_tokens_set.insert(token.content.clone());
            }