Python - Make AddedToken picklable

Anthony MOI
2020-06-19 10:34:11 -04:00
parent 63edb95130
commit 898a4a812e
2 changed files with 50 additions and 1 deletion

@@ -20,7 +20,7 @@ use tk::tokenizer::{
    PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
};

-#[pyclass(dict)]
+#[pyclass(dict, module = "tokenizers")]
pub struct AddedToken {
    pub token: tk::tokenizer::AddedToken,
}
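
The `module = "tokenizers"` argument matters for pickling: pickle records a class by its module-qualified name and re-imports it when loading, and without an explicit module a pyo3 class typically reports `builtins` as its module, which pickle cannot resolve back to the extension type. A minimal sketch of the expectation, assuming the package is built from this commit:

from tokenizers import AddedToken

# pickle looks the class up as "<module>.<qualname>" when deserializing
assert AddedToken.__module__ == "tokenizers"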
@@ -47,6 +47,40 @@ impl AddedToken {
        Ok(AddedToken { token })
    }

    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.token).map_err(|e| {
            exceptions::Exception::py_err(format!(
                "Error while attempting to pickle AddedToken: {}",
                e.to_string()
            ))
        })?;
        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
    }

    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        match state.extract::<&PyBytes>(py) {
            Ok(s) => {
                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
                    exceptions::Exception::py_err(format!(
                        "Error while attempting to unpickle AddedToken: {}",
                        e.to_string()
                    ))
                })?;
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
        // We don't really care about the values of `content` & `is_special_token` here because
        // they will get overridden by `__setstate__`
        let content: PyObject = "".into_py(py);
        let is_special_token: PyObject = false.into_py(py);
        let args = PyTuple::new(py, vec![content, is_special_token]);
        Ok(args)
    }

    #[getter]
    fn get_content(&self) -> &str {
        &self.token.content
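
Taken together, the three new methods plug AddedToken into Python's pickle protocol: `__getstate__` serializes the wrapped Rust struct to JSON bytes with serde_json, `__setstate__` parses those bytes back, and `__getnewargs__` hands pickle placeholder constructor arguments so it can allocate the object before restoring its state. Below is a rough sketch of what pickle does with these hooks, assuming the `tokenizers` package is built from this commit; the direct dunder calls are for illustration only.

from tokenizers import AddedToken

original = AddedToken("<mask>", True, lstrip=True)

args = original.__getnewargs__()    # placeholder ("", False) args for __new__
state = original.__getstate__()     # JSON bytes produced by serde_json

# Roughly what pickle.loads does when reconstructing the object
clone = AddedToken.__new__(AddedToken, *args)
clone.__setstate__(state)           # overwrites the placeholder values

assert clone.content == "<mask>"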

@@ -22,6 +22,8 @@ class TestAddedToken:
        assert added_token.rstrip == False
        assert added_token.lstrip == False
        assert added_token.single_word == False
        assert added_token.normalized == False
        assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)

    def test_can_set_rstrip(self):
        added_token = AddedToken("<mask>", True, rstrip=True)
@@ -41,6 +43,19 @@ class TestAddedToken:
        assert added_token.lstrip == False
        assert added_token.single_word == True

    def test_can_set_normalized(self):
        added_token = AddedToken("<mask>", True, normalized=True)
        assert added_token.rstrip == False
        assert added_token.lstrip == False
        assert added_token.single_word == False
        assert added_token.normalized == True

    def test_second_argument_defines_normalized(self):
        added_token = AddedToken("<mask>", True)
        assert added_token.normalized == False
        added_token = AddedToken("<mask>", False)
        assert added_token.normalized == True


class TestTokenizer:
    def test_has_expected_type_and_methods(self):
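
End to end, the new test lines boil down to the round trip below. The attribute checks are an assumption based on the serialized state; the diff itself only asserts the isinstance check. This sketch assumes the `tokenizers` package is built from this commit.

import pickle
from tokenizers import AddedToken

token = AddedToken("<mask>", True, rstrip=True)
restored = pickle.loads(pickle.dumps(token))

assert isinstance(restored, AddedToken)
assert restored.content == "<mask>"
assert restored.rstrip == True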