Remove Container from PostProcessors, replace with Arc.
* prefix the Python types in Rust with Py.
* remove unsound Container wrappers, replace with Arc.
committed by Anthony MOI
parent 0f81997351
commit 11e86a16c5
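Context for the change, as a minimal sketch rather than anything from the commit itself: Arc<dyn Trait> is the standard sound way to share a trait object, which is what replaces the hand-rolled Container wrapper the message calls unsound. Processor and Bert below are illustrative stand-ins, not crate types:

use std::sync::Arc;

// Illustrative stand-ins; the real trait is tk::tokenizer::PostProcessor.
trait Processor {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

struct Bert;

impl Processor for Bert {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

fn main() {
    // Arc<dyn Trait> is shared, reference-counted ownership of a trait
    // object: clones are cheap and safe, and the value is freed exactly
    // once, when the last handle drops. No raw pointers, no manual
    // lifetime bookkeeping.
    let shared: Arc<dyn Processor> = Arc::new(Bert);
    let copy = Arc::clone(&shared);
    assert_eq!(shared.added_tokens(true), 3);
    assert_eq!(copy.added_tokens(false), 2);
}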
@@ -1,40 +1,84 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;
 
-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use tk::processors::bert::BertProcessing;
+use tk::processors::byte_level::ByteLevel;
+use tk::processors::roberta::RobertaProcessing;
+use tk::{Encoding, PostProcessor};
+use tokenizers as tk;
 
-#[pyclass(dict, module = "tokenizers.processors")]
-pub struct PostProcessor {
-    pub processor: Container<dyn tk::tokenizer::PostProcessor>,
+#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
+#[derive(Clone)]
+pub struct PyPostProcessor {
+    pub processor: Arc<dyn PostProcessor>,
 }
 
+impl PyPostProcessor {
+    pub fn new(processor: Arc<dyn PostProcessor>) -> Self {
+        PyPostProcessor { processor }
+    }
+}
+
+#[typetag::serde]
+impl PostProcessor for PyPostProcessor {
+    fn added_tokens(&self, is_pair: bool) -> usize {
+        self.processor.added_tokens(is_pair)
+    }
+
+    fn process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Encoding> {
+        self.processor
+            .process(encoding, pair_encoding, add_special_tokens)
+    }
+}
+
+impl Serialize for PyPostProcessor {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.processor.serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for PyPostProcessor {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyPostProcessor::new(Arc::deserialize(deserializer)?))
+    }
+}
+
 #[pymethods]
-impl PostProcessor {
+impl PyPostProcessor {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .processor
-            .execute(|processor| serde_json::to_string(&processor))
-            .map_err(|e| {
-                exceptions::Exception::py_err(format!(
-                    "Error while attempting to pickle PostProcessor: {}",
-                    e.to_string()
-                ))
-            })?;
+        let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle PostProcessor: {}",
+                e.to_string()
+            ))
+        })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
     }
 
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.processor =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                        exceptions::Exception::py_err(format!(
-                            "Error while attempting to unpickle PostProcessor: {}",
-                            e.to_string()
-                        ))
-                    })?);
+                self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle PostProcessor: {}",
+                        e.to_string()
+                    ))
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
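The #[typetag::serde] attribute in this hunk is what makes the serde delegation work on a trait object: it records the concrete type next to the data, so __getstate__/__setstate__ can pickle through serde_json and get the right type back. A self-contained sketch of the mechanism, using illustrative Shape/Circle types rather than anything from this crate:

use serde::{Deserialize, Serialize};

// Illustrative trait; #[typetag::serde] tags each concrete impl by name.
#[typetag::serde]
trait Shape {
    fn area(&self) -> f64;
}

#[derive(Serialize, Deserialize)]
struct Circle {
    r: f64,
}

#[typetag::serde]
impl Shape for Circle {
    fn area(&self) -> f64 {
        std::f64::consts::PI * self.r * self.r
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let shape: Box<dyn Shape> = Box::new(Circle { r: 2.0 });
    // The concrete type is embedded in the JSON, e.g. {"Circle":{"r":2.0}},
    let json = serde_json::to_string(&shape)?;
    // ...so deserialization can reconstruct the right concrete type.
    let back: Box<dyn Shape> = serde_json::from_str(&json)?;
    assert_eq!(shape.area(), back.area());
    Ok(())
}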
@@ -42,23 +86,19 @@ impl PostProcessor {
     }
 
     fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
-        self.processor.execute(|p| p.added_tokens(is_pair))
+        self.processor.added_tokens(is_pair)
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct BertProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=BertProcessing)]
+pub struct PyBertProcessing {}
 #[pymethods]
-impl BertProcessing {
+impl PyBertProcessing {
     #[new]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> {
+    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PyPostProcessor)> {
         Ok((
-            BertProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
-                    sep, cls,
-                ))),
-            },
+            PyBertProcessing {},
+            PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls))),
         ))
     }
 
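The #[pyclass(extends=PyPostProcessor, ...)] pattern used here and in the hunks below is pyo3's class inheritance: #[new] returns a (subclass, base) tuple and pyo3 initializes both layers of the Python object. A stripped-down sketch of the same pattern, under the pyo3-era API this diff targets; Base and Child are hypothetical:

use pyo3::prelude::*;

#[pyclass]
struct Base {
    label: String,
}

#[pyclass(extends=Base)]
struct Child {}

#[pymethods]
impl Child {
    #[new]
    fn new() -> (Self, Base) {
        // Returning (subclass, base) lets pyo3 build both halves of the
        // Python object; Child() on the Python side is then an instance
        // of both Child and Base.
        (
            Child {},
            Base {
                label: "base".to_string(),
            },
        )
    }
}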
@@ -67,10 +107,10 @@ impl BertProcessing {
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct RobertaProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=RobertaProcessing)]
+pub struct PyRobertaProcessing {}
 #[pymethods]
-impl RobertaProcessing {
+impl PyRobertaProcessing {
     #[new]
     #[args(trim_offsets = true, add_prefix_space = true)]
     fn new(
@@ -78,17 +118,11 @@ impl RobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> PyResult<(Self, PostProcessor)> {
-        Ok((
-            RobertaProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(
-                    tk::processors::roberta::RobertaProcessing::new(sep, cls)
-                        .trim_offsets(trim_offsets)
-                        .add_prefix_space(add_prefix_space),
-                )),
-            },
-        ))
+    ) -> PyResult<(Self, PyPostProcessor)> {
+        let proc = RobertaProcessing::new(sep, cls)
+            .trim_offsets(trim_offsets)
+            .add_prefix_space(add_prefix_space);
+        Ok((PyRobertaProcessing {}, PyPostProcessor::new(Arc::new(proc))))
     }
 
     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
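The rewritten constructor binds the builder chain to a local before wrapping it once in Arc, instead of nesting the chain inside a struct literal. The consuming-builder shape it relies on, sketched with an illustrative Config type:

// Each setter takes `self` by value and returns it, so calls chain and
// the finished value can be moved straight into Arc::new.
struct Config {
    trim_offsets: bool,
    add_prefix_space: bool,
}

impl Config {
    fn new() -> Self {
        Config { trim_offsets: true, add_prefix_space: true }
    }
    fn trim_offsets(mut self, v: bool) -> Self {
        self.trim_offsets = v;
        self
    }
    fn add_prefix_space(mut self, v: bool) -> Self {
        self.add_prefix_space = v;
        self
    }
}

fn main() {
    let cfg = Config::new().trim_offsets(false).add_prefix_space(true);
    assert!(!cfg.trim_offsets && cfg.add_prefix_space);
}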
@@ -96,14 +130,14 @@ impl RobertaProcessing {
     }
 }
 
-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct ByteLevel {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=ByteLevel)]
+pub struct PyByteLevel {}
 #[pymethods]
-impl ByteLevel {
+impl PyByteLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PostProcessor)> {
-        let mut byte_level = tk::processors::byte_level::ByteLevel::default();
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPostProcessor)> {
+        let mut byte_level = ByteLevel::default();
 
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
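The kwargs loop whose body is elided between this hunk and the next follows pyo3's usual dict-dispatch pattern. A hedged sketch of that pattern only; apply_kwargs and the recognized key are assumptions for illustration, not the elided code:

use pyo3::prelude::*;
use pyo3::types::PyDict;

// Walk **kwargs, extract each key as &str, and dispatch on it.
fn apply_kwargs(mut trim_offsets: bool, kwargs: Option<&PyDict>) -> PyResult<bool> {
    if let Some(kwargs) = kwargs {
        for (key, value) in kwargs {
            let key: &str = key.extract()?;
            match key {
                "trim_offsets" => trim_offsets = value.extract()?,
                _ => println!("Ignored unknown kwarg option {}", key),
            }
        }
    }
    Ok(trim_offsets)
}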
@@ -114,11 +148,6 @@ impl ByteLevel {
                 }
             }
         }
-        Ok((
-            ByteLevel {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(byte_level)),
-            },
-        ))
+        Ok((PyByteLevel {}, PyPostProcessor::new(Arc::new(byte_level))))
     }
 }
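One consequence of the new #[derive(Clone)] on PyPostProcessor: cloning the wrapper only bumps the Arc refcount, so the underlying processor is never deep-copied and is shared by every clone. A small sketch of that shape; Wrapper is a stand-in, not a crate type:

use std::sync::Arc;

// Mirrors the wrapper's shape: derived Clone clones only the Arc handle.
#[derive(Clone)]
struct Wrapper {
    inner: Arc<String>, // stands in for Arc<dyn PostProcessor>
}

fn main() {
    let a = Wrapper { inner: Arc::new("processor".to_string()) };
    let b = a.clone(); // bumps the refcount; no deep copy
    assert_eq!(Arc::strong_count(&a.inner), 2);
    assert!(Arc::ptr_eq(&a.inner, &b.inner));
}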