From 898a4a812e83c43959fb1f7346d64127ce49a3ba Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 19 Jun 2020 10:34:11 -0400 Subject: [PATCH] Python - Make AddedToken picklable --- bindings/python/src/tokenizer.rs | 36 ++++++++++++++++++- .../python/tests/bindings/test_tokenizer.py | 15 ++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 7ef49714..7596c1c7 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -20,7 +20,7 @@ use tk::tokenizer::{ PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy, }; -#[pyclass(dict)] +#[pyclass(dict, module = "tokenizers")] pub struct AddedToken { pub token: tk::tokenizer::AddedToken, } @@ -47,6 +47,40 @@ impl AddedToken { Ok(AddedToken { token }) } + fn __getstate__(&self, py: Python) -> PyResult<PyObject> { let data = serde_json::to_string(&self.token).map_err(|e| { exceptions::Exception::py_err(format!( "Error while attempting to pickle AddedToken: {}", e.to_string() )) })?; Ok(PyBytes::new(py, data.as_bytes()).to_object(py)) } + + fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { match state.extract::<&PyBytes>(py) { Ok(s) => { self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| { exceptions::Exception::py_err(format!( "Error while attempting to unpickle AddedToken: {}", e.to_string() )) })?; Ok(()) } Err(e) => Err(e), } } + + fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> { // We don't really care about the values of `content` & `is_special_token` here because // they will get overridden by `__setstate__` let content: PyObject = "".into_py(py); let is_special_token: PyObject = false.into_py(py); let args = PyTuple::new(py, vec![content, is_special_token]); Ok(args) } + #[getter] fn get_content(&self) -> &str { &self.token.content diff --git
a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py index 548b70aa..2c4c97cf 100644 --- a/bindings/python/tests/bindings/test_tokenizer.py +++ b/bindings/python/tests/bindings/test_tokenizer.py @@ -22,6 +22,8 @@ class TestAddedToken: assert added_token.rstrip == False assert added_token.lstrip == False assert added_token.single_word == False + assert added_token.normalized == False + assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken) def test_can_set_rstrip(self): added_token = AddedToken("", True, rstrip=True) @@ -41,6 +43,19 @@ class TestAddedToken: assert added_token.lstrip == False assert added_token.single_word == True + def test_can_set_normalized(self): + added_token = AddedToken("", True, normalized=True) + assert added_token.rstrip == False + assert added_token.lstrip == False + assert added_token.single_word == False + assert added_token.normalized == True + + def test_second_argument_defines_normalized(self): + added_token = AddedToken("", True) + assert added_token.normalized == False + added_token = AddedToken("", False) + assert added_token.normalized == True + class TestTokenizer: def test_has_expected_type_and_methods(self):