Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-07 21:28:19 +00:00)
Python - Make AddedToken pickable

@@ -20,7 +20,7 @@ use tk::tokenizer::{
     PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy,
 };
 
-#[pyclass(dict)]
+#[pyclass(dict, module = "tokenizers")]
 pub struct AddedToken {
     pub token: tk::tokenizer::AddedToken,
 }
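
The key change in this hunk is `module = "tokenizers"`. Pickle does not serialize a class itself; it records the class's `__module__` and `__qualname__` and re-imports that path when loading. Without an explicit module name, pyo3 classes report `builtins` as their module, so pickling `AddedToken` would fail the lookup check. A minimal pure-Python sketch of that mechanism (the `Demo` class is hypothetical, for illustration only):

    import pickle

    class Demo:
        """Stand-in for an extension class; hypothetical."""

    obj = Demo()
    # pickle records the import path of the class, not the class object:
    print(type(obj).__module__, type(obj).__qualname__)  # __main__ Demo

    # dumps() verifies that importing __module__ and looking up __qualname__
    # resolves back to this exact class; a wrong __module__ raises
    # PicklingError. loads() performs the same import to rebuild the object.
    data = pickle.dumps(obj)
    restored = pickle.loads(data)
    assert isinstance(restored, Demo)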
@@ -47,6 +47,40 @@ impl AddedToken
         Ok(AddedToken { token })
     }
 
+    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+        let data = serde_json::to_string(&self.token).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle AddedToken: {}",
+                e.to_string()
+            ))
+        })?;
+        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
+    }
+
+    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+        match state.extract::<&PyBytes>(py) {
+            Ok(s) => {
+                self.token = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle AddedToken: {}",
+                        e.to_string()
+                    ))
+                })?;
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
+        // We don't really care about the values of `content` & `is_special_token` here because
+        // they will get overridden by `__setstate__`
+        let content: PyObject = "".into_py(py);
+        let is_special_token: PyObject = false.into_py(py);
+        let args = PyTuple::new(py, vec![content, is_special_token]);
+        Ok(args)
+    }
+
     #[getter]
     fn get_content(&self) -> &str {
         &self.token.content
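
For readers less used to pyo3, the three methods above implement Python's pickle protocol directly: `__getnewargs__` supplies placeholder constructor arguments (the diff's comment suggests the constructor takes `content` and `is_special_token`), `__getstate__` serializes the wrapped token to JSON bytes via serde, and `__setstate__` parses those bytes back on load. A pure-Python sketch of the same shape, with a hypothetical `Token` class standing in for the Rust pyclass:

    import json
    import pickle

    class Token:
        def __init__(self, content, is_special_token):
            self.content = content
            self.is_special_token = is_special_token

        def __getnewargs__(self):
            # Placeholder args, as in the Rust version: the real values
            # are overwritten by __setstate__ right after construction.
            return ("", False)

        def __getstate__(self):
            # Mirrors serde_json::to_string + PyBytes::new.
            return json.dumps(self.__dict__).encode()

        def __setstate__(self, state):
            # Mirrors serde_json::from_slice on the pickled bytes.
            self.__dict__.update(json.loads(state))

    restored = pickle.loads(pickle.dumps(Token("<mask>", True)))
    assert restored.content == "<mask>"
    assert restored.is_special_token is True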
@@ -22,6 +22,8 @@ class TestAddedToken:
         assert added_token.rstrip == False
         assert added_token.lstrip == False
         assert added_token.single_word == False
+        assert added_token.normalized == False
+        assert isinstance(pickle.loads(pickle.dumps(added_token)), AddedToken)
 
     def test_can_set_rstrip(self):
         added_token = AddedToken("<mask>", True, rstrip=True)
@@ -41,6 +43,19 @@ class TestAddedToken:
         assert added_token.lstrip == False
         assert added_token.single_word == True
 
+    def test_can_set_normalized(self):
+        added_token = AddedToken("<mask>", True, normalized=True)
+        assert added_token.rstrip == False
+        assert added_token.lstrip == False
+        assert added_token.single_word == False
+        assert added_token.normalized == True
+
+    def test_second_argument_defines_normalized(self):
+        added_token = AddedToken("<mask>", True)
+        assert added_token.normalized == False
+        added_token = AddedToken("<mask>", False)
+        assert added_token.normalized == True
+
 
 class TestTokenizer:
     def test_has_expected_type_and_methods(self):
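
The new tests pin down both the pickle round trip and the defaults of the `normalized` flag. For the pickling part, a usage sketch, assuming a build of the `tokenizers` Python package that includes this change:

    import pickle

    from tokenizers import AddedToken

    token = AddedToken("<mask>", True, rstrip=True)
    restored = pickle.loads(pickle.dumps(token))

    assert isinstance(restored, AddedToken)
    assert restored.content == "<mask>"
    assert restored.rstrip == True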