Ensure serialization works in all expected ways.

This commit is contained in:
Sebastian Puetz
2020-08-01 13:34:18 +02:00
committed by Anthony MOI
parent aaf8e932b1
commit 16f75d9efc
39 changed files with 1303 additions and 615 deletions

View File

@ -4,26 +4,44 @@ use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use tk::processors::bert::BertProcessing;
use tk::processors::byte_level::ByteLevel;
use tk::processors::roberta::RobertaProcessing;
use tk::processors::PostProcessorWrapper;
use tk::{Encoding, PostProcessor};
use tokenizers as tk;
#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
#[derive(Clone)]
#[derive(Clone, Deserialize, Serialize)]
pub struct PyPostProcessor {
pub processor: Arc<dyn PostProcessor>,
#[serde(flatten)]
pub processor: Arc<PostProcessorWrapper>,
}
impl PyPostProcessor {
pub fn new(processor: Arc<dyn PostProcessor>) -> Self {
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
PyPostProcessor { processor }
}
/// Builds a Python object of the subclass that corresponds to the wrapped
/// post-processor variant.
///
/// A clone of `self` is paired with the matching marker struct
/// (`PyByteLevel`, `PyBertProcessing` or `PyRobertaProcessing`) and handed
/// to `Py::new`; the resulting handle is converted into a `PyObject`.
///
/// # Errors
///
/// Returns whatever error `Py::new` reports when the object cannot be
/// created.
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
    let base = self.clone();
    let gil = Python::acquire_gil();
    let py = gil.python();
    // Each arm produces a differently typed `Py<_>`, so the conversion to the
    // common `PyObject` type has to happen inside the arm.
    let object: PyObject = match &*self.processor {
        PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into(),
        PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into(),
        PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))?.into(),
    };
    Ok(object)
}
}
#[typetag::serde]
impl PostProcessor for PyPostProcessor {
fn added_tokens(&self, is_pair: bool) -> usize {
self.processor.added_tokens(is_pair)
@ -40,24 +58,6 @@ impl PostProcessor for PyPostProcessor {
}
}
/// Hand-written `Serialize` impl for `PyPostProcessor`.
///
/// The wrapper contributes nothing of its own to the output: the serializer
/// is passed straight to the `processor` field, so the serialized form is
/// whatever that field's `Serialize` implementation produces.
impl Serialize for PyPostProcessor {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Pure delegation to the wrapped processor.
self.processor.serialize(serializer)
}
}
/// Hand-written `Deserialize` impl for `PyPostProcessor`.
///
/// Deserializes an `Arc` from the input and wraps it with
/// `PyPostProcessor::new`. The payload type of the `Arc` is inferred from the
/// parameter type of `new`. Any deserialization error is returned unchanged.
impl<'de> Deserialize<'de> for PyPostProcessor {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// `?` propagates the deserializer's error; on success the `Arc` becomes
// the wrapper's `processor`.
Ok(PyPostProcessor::new(Arc::deserialize(deserializer)?))
}
}
#[pymethods]
impl PyPostProcessor {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
@ -98,7 +98,7 @@ impl PyBertProcessing {
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PyPostProcessor)> {
Ok((
PyBertProcessing {},
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls))),
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())),
))
}
@ -122,7 +122,10 @@ impl PyRobertaProcessing {
let proc = RobertaProcessing::new(sep, cls)
.trim_offsets(trim_offsets)
.add_prefix_space(add_prefix_space);
Ok((PyRobertaProcessing {}, PyPostProcessor::new(Arc::new(proc))))
Ok((
PyRobertaProcessing {},
PyPostProcessor::new(Arc::new(proc.into())),
))
}
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
@ -148,6 +151,58 @@ impl PyByteLevel {
}
}
}
Ok((PyByteLevel {}, PyPostProcessor::new(Arc::new(byte_level))))
Ok((
PyByteLevel {},
PyPostProcessor::new(Arc::new(byte_level.into())),
))
}
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use pyo3::{AsPyRef, Python};
    use tk::processors::bert::BertProcessing;
    use tk::processors::PostProcessorWrapper;

    use crate::processors::PyPostProcessor;

    /// A Bert-backed `PyPostProcessor` must be exposed to Python as
    /// `tokenizers.processors.BertProcessing`.
    #[test]
    fn get_subtype() {
        let py_proc = PyPostProcessor::new(Arc::new(
            BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(),
        ));
        let py_bert = py_proc.get_as_subtype().unwrap();
        let gil = Python::acquire_gil();
        assert_eq!(
            "tokenizers.processors.BertProcessing",
            py_bert.as_ref(gil.python()).get_type().name()
        );
    }

    /// The JSON produced for a `PyPostProcessor` must be identical to the JSON
    /// of both the bare `BertProcessing` and its `PostProcessorWrapper`, and
    /// either JSON string must deserialize back into the `Bert` variant.
    #[test]
    fn serialize() {
        let rs_processing = BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1));
        // Serialize the bare processor first so it can then be moved into the
        // wrapper without cloning.
        let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap();
        let rs_wrapper: PostProcessorWrapper = rs_processing.into();
        let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();

        // The wrapper is not needed afterwards, so it is moved, not cloned.
        let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper));
        let py_ser = serde_json::to_string(&py_processing).unwrap();
        assert_eq!(py_ser, rs_processing_ser);
        assert_eq!(py_ser, rs_wrapper_ser);

        // Round-trip from both serialized forms.
        for &json in &[rs_processing_ser.as_str(), rs_wrapper_ser.as_str()] {
            let py_processing: PyPostProcessor = serde_json::from_str(json).unwrap();
            match py_processing.processor.as_ref() {
                PostProcessorWrapper::Bert(_) => (),
                _ => panic!("Expected Bert postprocessor."),
            }
        }
    }
}