mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-07 13:18:31 +00:00
Python - Make AddedToken picklable
This commit is contained in:
@@ -20,7 +20,7 @@ use tk::tokenizer::{
|
||||
PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
|
||||
};
|
||||
|
||||
#[pyclass(dict)]
|
||||
#[pyclass(dict, module = "tokenizers")]
|
||||
pub struct AddedToken {
|
||||
pub token: tk::tokenizer::AddedToken,
|
||||
}
|
||||
@@ -47,6 +47,40 @@ impl AddedToken {
|
||||
Ok(AddedToken { token })
|
||||
}
|
||||
|
||||
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
|
||||
let data = serde_json::to_string(&self.token).map_err(|e| {
|
||||
exceptions::Exception::py_err(format!(
|
||||
"Error while attempting to pickle AddedToken: {}",
|
||||
e.to_string()
|
||||
))
|
||||
})?;
|
||||
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
|
||||
}
|
||||
|
||||
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
|
||||
match state.extract::<&PyBytes>(py) {
|
||||
Ok(s) => {
|
||||
self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
|
||||
exceptions::Exception::py_err(format!(
|
||||
"Error while attempting to unpickle AddedToken: {}",
|
||||
e.to_string()
|
||||
))
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
|
||||
// We don't really care about the values of `content` & `is_special_token` here because
|
||||
// they will get overriden by `__setstate__`
|
||||
let content: PyObject = "".into_py(py);
|
||||
let is_special_token: PyObject = false.into_py(py);
|
||||
let args = PyTuple::new(py, vec![content, is_special_token]);
|
||||
Ok(args)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_content(&self) -> &str {
|
||||
&self.token.content
|
||||
|
||||
@@ -22,6 +22,8 @@ class TestAddedToken:
|
||||
assert added_token.rstrip == False
|
||||
assert added_token.lstrip == False
|
||||
assert added_token.single_word == False
|
||||
assert added_token.normalized == False
|
||||
assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)
|
||||
|
||||
def test_can_set_rstrip(self):
|
||||
added_token = AddedToken("<mask>", True, rstrip=True)
|
||||
@@ -41,6 +43,19 @@ class TestAddedToken:
|
||||
assert added_token.lstrip == False
|
||||
assert added_token.single_word == True
|
||||
|
||||
def test_can_set_normalized(self):
    """Passing normalized=True sets only the `normalized` flag."""
    token = AddedToken("<mask>", True, normalized=True)
    # Every other flag keeps its default value.
    assert token.rstrip == False
    assert token.lstrip == False
    assert token.single_word == False
    assert token.normalized == True
|
||||
|
||||
def test_second_argument_defines_normalized(self):
    """The positional special-token flag flips the default of `normalized`."""
    # A special token defaults to normalized == False ...
    special = AddedToken("<mask>", True)
    assert special.normalized == False
    # ... while a non-special token defaults to normalized == True.
    regular = AddedToken("<mask>", False)
    assert regular.normalized == True
|
||||
|
||||
|
||||
class TestTokenizer:
|
||||
def test_has_expected_type_and_methods(self):
|
||||
|
||||
Reference in New Issue
Block a user