mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Ensure serialization works in all expected ways.
This commit is contained in:
committed by
Anthony MOI
parent
aaf8e932b1
commit
16f75d9efc
@ -4,26 +4,44 @@ use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tk::processors::bert::BertProcessing;
|
||||
use tk::processors::byte_level::ByteLevel;
|
||||
use tk::processors::roberta::RobertaProcessing;
|
||||
use tk::processors::PostProcessorWrapper;
|
||||
use tk::{Encoding, PostProcessor};
|
||||
use tokenizers as tk;
|
||||
|
||||
#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct PyPostProcessor {
|
||||
pub processor: Arc<dyn PostProcessor>,
|
||||
#[serde(flatten)]
|
||||
pub processor: Arc<PostProcessorWrapper>,
|
||||
}
|
||||
|
||||
impl PyPostProcessor {
|
||||
pub fn new(processor: Arc<dyn PostProcessor>) -> Self {
|
||||
pub fn new(processor: Arc<PostProcessorWrapper>) -> Self {
|
||||
PyPostProcessor { processor }
|
||||
}
|
||||
|
||||
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||
let base = self.clone();
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
match self.processor.as_ref() {
|
||||
PostProcessorWrapper::ByteLevel(_) => {
|
||||
Py::new(py, (PyByteLevel {}, base)).map(Into::into)
|
||||
}
|
||||
PostProcessorWrapper::Bert(_) => {
|
||||
Py::new(py, (PyBertProcessing {}, base)).map(Into::into)
|
||||
}
|
||||
PostProcessorWrapper::Roberta(_) => {
|
||||
Py::new(py, (PyRobertaProcessing {}, base)).map(Into::into)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl PostProcessor for PyPostProcessor {
|
||||
fn added_tokens(&self, is_pair: bool) -> usize {
|
||||
self.processor.added_tokens(is_pair)
|
||||
@ -40,24 +58,6 @@ impl PostProcessor for PyPostProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for PyPostProcessor {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
self.processor.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for PyPostProcessor {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Ok(PyPostProcessor::new(Arc::deserialize(deserializer)?))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyPostProcessor {
|
||||
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
|
||||
@ -98,7 +98,7 @@ impl PyBertProcessing {
|
||||
fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PyPostProcessor)> {
|
||||
Ok((
|
||||
PyBertProcessing {},
|
||||
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls))),
|
||||
PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())),
|
||||
))
|
||||
}
|
||||
|
||||
@ -122,7 +122,10 @@ impl PyRobertaProcessing {
|
||||
let proc = RobertaProcessing::new(sep, cls)
|
||||
.trim_offsets(trim_offsets)
|
||||
.add_prefix_space(add_prefix_space);
|
||||
Ok((PyRobertaProcessing {}, PyPostProcessor::new(Arc::new(proc))))
|
||||
Ok((
|
||||
PyRobertaProcessing {},
|
||||
PyPostProcessor::new(Arc::new(proc.into())),
|
||||
))
|
||||
}
|
||||
|
||||
fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
|
||||
@ -148,6 +151,58 @@ impl PyByteLevel {
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok((PyByteLevel {}, PyPostProcessor::new(Arc::new(byte_level))))
|
||||
Ok((
|
||||
PyByteLevel {},
|
||||
PyPostProcessor::new(Arc::new(byte_level.into())),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
use pyo3::{AsPyRef, Python};
|
||||
use tk::processors::bert::BertProcessing;
|
||||
use tk::processors::PostProcessorWrapper;
|
||||
|
||||
use crate::processors::PyPostProcessor;
|
||||
|
||||
#[test]
|
||||
fn get_subtype() {
|
||||
let py_proc = PyPostProcessor::new(Arc::new(
|
||||
BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(),
|
||||
));
|
||||
let py_bert = py_proc.get_as_subtype().unwrap();
|
||||
let gil = Python::acquire_gil();
|
||||
assert_eq!(
|
||||
"tokenizers.processors.BertProcessing",
|
||||
py_bert.as_ref(gil.python()).get_type().name()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize() {
|
||||
let rs_processing = BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1));
|
||||
let rs_wrapper: PostProcessorWrapper = rs_processing.clone().into();
|
||||
let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap();
|
||||
let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap();
|
||||
|
||||
let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper.clone()));
|
||||
let py_ser = serde_json::to_string(&py_processing).unwrap();
|
||||
assert_eq!(py_ser, rs_processing_ser);
|
||||
assert_eq!(py_ser, rs_wrapper_ser);
|
||||
|
||||
let py_processing: PyPostProcessor = serde_json::from_str(&rs_processing_ser).unwrap();
|
||||
match py_processing.processor.as_ref() {
|
||||
PostProcessorWrapper::Bert(_) => (),
|
||||
_ => panic!("Expected Bert postprocessor."),
|
||||
}
|
||||
|
||||
let py_processing: PyPostProcessor = serde_json::from_str(&rs_wrapper_ser).unwrap();
|
||||
match py_processing.processor.as_ref() {
|
||||
PostProcessorWrapper::Bert(_) => (),
|
||||
_ => panic!("Expected Bert postprocessor."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user