Remove Container from PostProcessors, replace with Arc.

* prefix the Python types in Rust with Py.
* remove unsound Container wrappers, replace with Arc.
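
The gist of the change, as a minimal standalone sketch (stand-in trait and type names only, not the real tokenizers API): holding the processor behind an Arc makes the Python-facing wrapper cheaply cloneable, so the tokenizer and the Python object can share one processor instead of moving exclusive ownership through a Container.

use std::sync::Arc;

// Stand-in trait for tk::PostProcessor; only here to illustrate the wrapper shape.
trait Process {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

struct Bert;

impl Process for Bert {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

// The Python-facing wrapper owns its processor through an Arc, so Clone is just a
// refcount bump and several owners can share one instance, instead of the old
// Container moving ownership back and forth.
#[derive(Clone)]
struct PyWrapper {
    processor: Arc<dyn Process>,
}

fn main() {
    let py_side = PyWrapper { processor: Arc::new(Bert) };
    let tokenizer_side = py_side.clone(); // cheap, no exclusive-ownership dance
    assert_eq!(tokenizer_side.processor.added_tokens(true), 3);
    assert_eq!(py_side.processor.added_tokens(false), 2);
}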
Sebastian Pütz authored 2020-07-25 19:34:38 +02:00, committed by Anthony MOI
parent 0f81997351 / commit 11e86a16c5
3 changed files with 100 additions and 83 deletions


@@ -82,10 +82,10 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
 /// Processors Module
 #[pymodule]
 fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<processors::PostProcessor>()?;
-    m.add_class::<processors::BertProcessing>()?;
-    m.add_class::<processors::RobertaProcessing>()?;
-    m.add_class::<processors::ByteLevel>()?;
+    m.add_class::<processors::PyPostProcessor>()?;
+    m.add_class::<processors::PyBertProcessing>()?;
+    m.add_class::<processors::PyRobertaProcessing>()?;
+    m.add_class::<processors::PyByteLevel>()?;
     Ok(())
 }


@@ -1,40 +1,84 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;
 
-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use tk::processors::bert::BertProcessing;
+use tk::processors::byte_level::ByteLevel;
+use tk::processors::roberta::RobertaProcessing;
+use tk::{Encoding, PostProcessor};
+use tokenizers as tk;
 
-#[pyclass(dict, module = "tokenizers.processors")]
-pub struct PostProcessor {
-    pub processor: Container<dyn tk::tokenizer::PostProcessor>,
+#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
+#[derive(Clone)]
+pub struct PyPostProcessor {
+    pub processor: Arc<dyn PostProcessor>,
+}
+
+impl PyPostProcessor {
+    pub fn new(processor: Arc<dyn PostProcessor>) -> Self {
+        PyPostProcessor { processor }
+    }
+}
+
+#[typetag::serde]
+impl PostProcessor for PyPostProcessor {
+    fn added_tokens(&self, is_pair: bool) -> usize {
+        self.processor.added_tokens(is_pair)
+    }
+
+    fn process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Encoding> {
+        self.processor
+            .process(encoding, pair_encoding, add_special_tokens)
+    }
+}
+
+impl Serialize for PyPostProcessor {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.processor.serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for PyPostProcessor {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyPostProcessor::new(Arc::deserialize(deserializer)?))
+    }
 }
 
 #[pymethods]
-impl PostProcessor {
+impl PyPostProcessor {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .processor
-            .execute(|processor| serde_json::to_string(&processor))
-            .map_err(|e| {
-                exceptions::Exception::py_err(format!(
-                    "Error while attempting to pickle PostProcessor: {}",
-                    e.to_string()
-                ))
-            })?;
+        let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle PostProcessor: {}",
+                e.to_string()
+            ))
+        })?;
 
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
     }
 
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.processor =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                        exceptions::Exception::py_err(format!(
-                            "Error while attempting to unpickle PostProcessor: {}",
-                            e.to_string()
-                        ))
-                    })?);
+                self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle PostProcessor: {}",
+                        e.to_string()
+                    ))
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
@@ -42,23 +86,19 @@ impl PostProcessor {
     }
 
     fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
-        self.processor.execute(|p| p.added_tokens(is_pair))
+        self.processor.added_tokens(is_pair)
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct BertProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=BertProcessing)]
+pub struct PyBertProcessing {}
 
 #[pymethods]
-impl BertProcessing {
+impl PyBertProcessing {
     #[new]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> {
+    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PyPostProcessor)> {
         Ok((
-            BertProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
-                    sep, cls,
-                ))),
-            },
+            PyBertProcessing {},
+            PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls))),
         ))
     }
@@ -67,10 +107,10 @@ impl BertProcessing {
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct RobertaProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=RobertaProcessing)]
+pub struct PyRobertaProcessing {}
 
 #[pymethods]
-impl RobertaProcessing {
+impl PyRobertaProcessing {
     #[new]
     #[args(trim_offsets = true, add_prefix_space = true)]
     fn new(
@@ -78,17 +118,11 @@ impl RobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> PyResult<(Self, PostProcessor)> {
-        Ok((
-            RobertaProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(
-                    tk::processors::roberta::RobertaProcessing::new(sep, cls)
-                        .trim_offsets(trim_offsets)
-                        .add_prefix_space(add_prefix_space),
-                )),
-            },
-        ))
+    ) -> PyResult<(Self, PyPostProcessor)> {
+        let proc = RobertaProcessing::new(sep, cls)
+            .trim_offsets(trim_offsets)
+            .add_prefix_space(add_prefix_space);
+        Ok((PyRobertaProcessing {}, PyPostProcessor::new(Arc::new(proc))))
     }
 
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
@@ -96,14 +130,14 @@ impl RobertaProcessing {
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct ByteLevel {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=ByteLevel)]
+pub struct PyByteLevel {}
 
 #[pymethods]
-impl ByteLevel {
+impl PyByteLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PostProcessor)> {
-        let mut byte_level = tk::processors::byte_level::ByteLevel::default();
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPostProcessor)> {
+        let mut byte_level = ByteLevel::default();
 
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
@@ -114,11 +148,6 @@ impl ByteLevel {
                 }
             }
         }
-        Ok((
-            ByteLevel {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(byte_level)),
-            },
-        ))
+        Ok((PyByteLevel {}, PyPostProcessor::new(Arc::new(byte_level))))
     }
 }
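
For context on the new __getstate__/__setstate__ path above: PyPostProcessor now delegates (de)serialization to typetag, so a trait object can round-trip through serde_json. A self-contained sketch of that mechanism, with toy trait and struct names rather than the tokenizers types, assuming serde (with derive), serde_json and typetag as dependencies:

use std::sync::Arc;

// Stand-in for a #[typetag::serde]-annotated trait like tk::PostProcessor.
#[typetag::serde]
trait Step {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

#[derive(serde::Serialize, serde::Deserialize)]
struct BertLike {
    sep_id: u32,
    cls_id: u32,
}

#[typetag::serde]
impl Step for BertLike {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let original: Arc<dyn Step> = Arc::new(BertLike { sep_id: 102, cls_id: 101 });

    // __getstate__ side: serialize the trait object; typetag records the concrete type.
    let state = serde_json::to_string(original.as_ref())?;

    // __setstate__ side: deserialize straight back into an Arc<dyn Step>.
    let restored: Arc<dyn Step> = serde_json::from_str(&state)?;
    assert_eq!(restored.added_tokens(true), original.added_tokens(true));
    Ok(())
}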


@@ -7,7 +7,7 @@ use pyo3::types::*;
 use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
-    PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
+    PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, Tokenizer, TruncationParams,
     TruncationStrategy,
 };
 use tokenizers as tk;
@@ -18,9 +18,9 @@ use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
 use super::normalizers::PyNormalizer;
 use super::pre_tokenizers::PyPreTokenizer;
-use super::processors::PostProcessor;
 use super::trainers::PyTrainer;
 use super::utils::Container;
+use crate::processors::PyPostProcessor;
 
 #[pyclass(dict, module = "tokenizers", name=AddedToken)]
 pub struct PyAddedToken {
@@ -268,9 +268,9 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }
 
-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor>;
 
-#[pyclass(dict, module = "tokenizers")]
+#[pyclass(dict, module = "tokenizers", name=Tokenizer)]
 pub struct PyTokenizer {
     tokenizer: TokenizerImpl,
 }
@@ -362,7 +362,7 @@ impl PyTokenizer {
         Ok(self
             .tokenizer
             .get_post_processor()
-            .map_or(0, |p| p.as_ref().added_tokens(is_pair)))
+            .map_or(0, |p| p.added_tokens(is_pair)))
     }
 
     #[args(with_added_tokens = true)]
@@ -727,25 +727,13 @@ impl PyTokenizer {
     }
 
     #[getter]
-    fn get_post_processor(&self) -> PyResult<Option<PostProcessor>> {
-        Ok(self
-            .tokenizer
-            .get_post_processor()
-            .map(|processor| PostProcessor {
-                processor: Container::from_ref(processor),
-            }))
+    fn get_post_processor(&self) -> Option<PyPostProcessor> {
+        self.tokenizer.get_post_processor().cloned()
     }
 
     #[setter]
-    fn set_post_processor(&mut self, mut processor: PyRefMut<PostProcessor>) -> PyResult<()> {
-        if let Some(processor) = processor.processor.to_pointer() {
-            self.tokenizer.with_post_processor(processor);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Processor is already being used in another Tokenizer",
-            ))
-        }
+    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
+        self.tokenizer.with_post_processor(processor.clone());
     }
 
     #[getter]
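
The getter/setter change above rests on the same shared-ownership idea: get_post_processor hands back a clone of the wrapper, set_post_processor stores one, and every clone aliases the same underlying processor, so the old "already being used in another Tokenizer" error path goes away. A rough model of that behaviour, using hypothetical Wrapper/Holder types rather than the pyo3 bindings:

use std::sync::Arc;

// Stand-in for the PyPostProcessor wrapper: Clone only bumps the Arc refcount.
#[derive(Clone)]
struct Wrapper {
    processor: Arc<String>, // placeholder payload instead of a dyn PostProcessor
}

// Stand-in for the tokenizer side of the new getter/setter.
#[derive(Default)]
struct Holder {
    post_processor: Option<Wrapper>,
}

impl Holder {
    // Mirrors set_post_processor: borrow the caller's wrapper and store a clone.
    fn set(&mut self, wrapper: &Wrapper) {
        self.post_processor = Some(wrapper.clone());
    }

    // Mirrors get_post_processor: hand back a cloned wrapper, not a borrow.
    fn get(&self) -> Option<Wrapper> {
        self.post_processor.clone()
    }
}

fn main() {
    let caller_side = Wrapper { processor: Arc::new("byte-level".to_string()) };
    let mut holder = Holder::default();
    holder.set(&caller_side);

    // Both handles alias the same underlying processor; nothing was moved out,
    // so no exclusive-ownership error can occur.
    let from_getter = holder.get().unwrap();
    assert!(Arc::ptr_eq(&caller_side.processor, &from_getter.processor));
}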