Remove Container from PostProcessors, replace with Arc.

* prefix the Python types in Rust with Py.
* remove unsound Container wrappers, replace with Arc.
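For context: the Container being removed could hold either an owned trait object or a raw pointer into a value owned elsewhere, and nothing tied that pointer to the owner's lifetime, which is what made the wrapper unsound. Arc<dyn Trait> provides the same shared access with reference counting instead. A minimal sketch of the before/after shapes (illustrative names, not the exact definitions from this repo):

```rust
use std::sync::Arc;

trait Processor {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

// Roughly the old shape: either owned, or a raw pointer aliasing a value
// owned by someone else. The Pointer variant can dangle if the owner drops
// the value first; the compiler cannot check it.
#[allow(dead_code)]
enum Container<T: ?Sized> {
    Owned(Box<T>),
    Pointer(*mut T),
}

// The replacement: shared ownership via atomic reference counting. Every
// handle keeps the processor alive until the last one drops.
struct PyPostProcessor {
    processor: Arc<dyn Processor>,
}

impl PyPostProcessor {
    fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
        // Arc derefs straight to the trait object: no accessor needed.
        self.processor.added_tokens(is_pair)
    }
}

struct Bert;

impl Processor for Bert {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

fn main() {
    let wrapper = PyPostProcessor { processor: Arc::new(Bert) };
    // Cloning the Arc is cheap; both handles co-own the processor.
    let shared: Arc<dyn Processor> = Arc::clone(&wrapper.processor);
    assert_eq!(shared.added_tokens(false), 2);
    assert_eq!(wrapper.num_special_tokens_to_add(true), 3);
}
```

Because Arc derefs directly to the trait object, the closure-based execute accessor used throughout the old code also disappears, as the diff below shows.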
Sebastian Pütz, 2020-07-25 19:34:38 +02:00 (committed by Anthony MOI)
parent 0f81997351, commit 11e86a16c5
3 changed files with 100 additions and 83 deletions

bindings/python/src/processors.rs
@@ -1,40 +1,84 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;
 
-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
 
-#[pyclass(dict, module = "tokenizers.processors")]
-pub struct PostProcessor {
-    pub processor: Container<dyn tk::tokenizer::PostProcessor>,
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+
+use tk::processors::bert::BertProcessing;
+use tk::processors::byte_level::ByteLevel;
+use tk::processors::roberta::RobertaProcessing;
+use tk::{Encoding, PostProcessor};
+use tokenizers as tk;
+
+#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
+#[derive(Clone)]
+pub struct PyPostProcessor {
+    pub processor: Arc<dyn PostProcessor>,
 }
+
+impl PyPostProcessor {
+    pub fn new(processor: Arc<dyn PostProcessor>) -> Self {
+        PyPostProcessor { processor }
+    }
+}
+
+#[typetag::serde]
+impl PostProcessor for PyPostProcessor {
+    fn added_tokens(&self, is_pair: bool) -> usize {
+        self.processor.added_tokens(is_pair)
+    }
+
+    fn process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Encoding> {
+        self.processor
+            .process(encoding, pair_encoding, add_special_tokens)
+    }
+}
+
+impl Serialize for PyPostProcessor {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.processor.serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for PyPostProcessor {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyPostProcessor::new(Arc::deserialize(deserializer)?))
+    }
+}
 
 #[pymethods]
-impl PostProcessor {
+impl PyPostProcessor {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .processor
-            .execute(|processor| serde_json::to_string(&processor))
-            .map_err(|e| {
-                exceptions::Exception::py_err(format!(
-                    "Error while attempting to pickle PostProcessor: {}",
-                    e.to_string()
-                ))
-            })?;
+        let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle PostProcessor: {}",
+                e.to_string()
+            ))
+        })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
     }
 
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.processor =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                        exceptions::Exception::py_err(format!(
-                            "Error while attempting to unpickle PostProcessor: {}",
-                            e.to_string()
-                        ))
-                    })?);
+                self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle PostProcessor: {}",
+                        e.to_string()
+                    ))
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
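The Serialize/Deserialize impls above, combined with #[typetag::serde], are what keep pickling via __getstate__/__setstate__ working: the trait object serializes to JSON with the concrete type's name embedded, and deserialization uses that tag to rebuild the right processor behind the Arc. A self-contained sketch of the round trip, using a hypothetical Process trait rather than the real tk::PostProcessor:

```rust
use std::sync::Arc;

use serde::{Deserialize, Serialize};

// Annotating the trait makes `dyn Process` itself (de)serializable;
// typetag stores the concrete type's name alongside its fields.
#[typetag::serde]
trait Process {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

#[derive(Serialize, Deserialize)]
struct Bert {
    sep: String,
}

#[typetag::serde]
impl Process for Bert {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

fn main() -> serde_json::Result<()> {
    let p: Arc<dyn Process> = Arc::new(Bert { sep: "[SEP]".into() });
    // Serialize through the trait object; the JSON carries a tag,
    // e.g. {"Bert":{"sep":"[SEP]"}}, so it can be reconstructed later.
    let json = serde_json::to_string(p.as_ref())?;
    // typetag registers a deserializer per impl, so the box comes back
    // with the right concrete type behind it.
    let back: Box<dyn Process> = serde_json::from_str(&json)?;
    assert_eq!(back.added_tokens(true), 3);
    Ok(())
}
```

Deserializing directly into an Arc, as Arc::deserialize does in the diff, additionally relies on serde's rc feature; the Box route shown here is what typetag provides out of the box.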
@@ -42,23 +86,19 @@ impl PostProcessor {
     }
 
     fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
-        self.processor.execute(|p| p.added_tokens(is_pair))
+        self.processor.added_tokens(is_pair)
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct BertProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=BertProcessing)]
+pub struct PyBertProcessing {}
 #[pymethods]
-impl BertProcessing {
+impl PyBertProcessing {
     #[new]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> {
+    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PyPostProcessor)> {
         Ok((
-            BertProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
-                    sep, cls,
-                ))),
-            },
+            PyBertProcessing {},
+            PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls))),
         ))
     }
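Each concrete processor repeats the same PyO3 idiom: a #[pyclass(extends=PyPostProcessor)] subclass whose #[new] returns a (Subclass, Base) tuple, which PyO3 converts into an initializer that fills in the native base-class slot alongside the subclass. A stripped-down sketch with made-up names, using the module-function signature of the PyO3 generation this diff targets (current PyO3 also requires `subclass` on the base):

```rust
use pyo3::prelude::*;

// Shared base class holding the actual state; `subclass` allows other
// #[pyclass] types to extend it.
#[pyclass(subclass)]
struct Base {
    label: String,
}

// Python-visible subclass; all real state lives in the base part.
#[pyclass(extends=Base)]
struct Child {}

#[pymethods]
impl Child {
    #[new]
    fn new() -> (Self, Base) {
        // PyO3 turns the tuple into a PyClassInitializer: the second
        // element initializes the base slot, the first the subclass.
        (Child {}, Base { label: "child".to_string() })
    }
}

#[pymodule]
fn demo(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Base>()?;
    m.add_class::<Child>()?;
    Ok(())
}
```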
@@ -67,10 +107,10 @@ impl BertProcessing {
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct RobertaProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=RobertaProcessing)]
+pub struct PyRobertaProcessing {}
 #[pymethods]
-impl RobertaProcessing {
+impl PyRobertaProcessing {
     #[new]
     #[args(trim_offsets = true, add_prefix_space = true)]
     fn new(
@@ -78,17 +118,11 @@ impl RobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> PyResult<(Self, PostProcessor)> {
-        Ok((
-            RobertaProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(
-                    tk::processors::roberta::RobertaProcessing::new(sep, cls)
-                        .trim_offsets(trim_offsets)
-                        .add_prefix_space(add_prefix_space),
-                )),
-            },
-        ))
+    ) -> PyResult<(Self, PyPostProcessor)> {
+        let proc = RobertaProcessing::new(sep, cls)
+            .trim_offsets(trim_offsets)
+            .add_prefix_space(add_prefix_space);
+        Ok((PyRobertaProcessing {}, PyPostProcessor::new(Arc::new(proc))))
     }
 
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
@@ -96,14 +130,14 @@ impl RobertaProcessing {
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct ByteLevel {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=ByteLevel)]
+pub struct PyByteLevel {}
 #[pymethods]
-impl ByteLevel {
+impl PyByteLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PostProcessor)> {
-        let mut byte_level = tk::processors::byte_level::ByteLevel::default();
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPostProcessor)> {
+        let mut byte_level = ByteLevel::default();
 
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
@@ -114,11 +148,6 @@ impl ByteLevel {
                 }
             }
         }
-        Ok((
-            ByteLevel {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(byte_level)),
-            },
-        ))
+        Ok((PyByteLevel {}, PyPostProcessor::new(Arc::new(byte_level))))
     }
 }