Python - Make all relevant classes pickable

This commit is contained in:
Anthony MOI
2020-05-18 19:04:00 -04:00
parent 93bb82c657
commit 6a70162d78
12 changed files with 229 additions and 15 deletions

View File

@ -1,12 +1,13 @@
extern crate tokenizers as tk;
use crate::error::PyError;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::{Offsets, PaddingDirection};
#[pyclass(dict)]
#[pyclass(dict, module = "tokenizers")]
#[repr(transparent)]
pub struct Encoding {
pub encoding: tk::tokenizer::Encoding,
@ -38,6 +39,31 @@ impl PySequenceProtocol for Encoding {
#[pymethods]
impl Encoding {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(&self.encoding).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to pickle Encoding: {}",
e.to_string()
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
}
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.encoding = serde_json::from_slice(s.as_bytes()).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to unpickle Encoding: {}",
e.to_string()
))
})?;
Ok(())
}
Err(e) => Err(e),
}
}
#[staticmethod]
#[args(growing_offsets = true)]
fn merge(encodings: Vec<PyRef<Encoding>>, growing_offsets: bool) -> Encoding {