Remove Container from PostProcessors, replace with Arc.

* prefix the Python types in Rust with Py.
* remove unsound Container wrappers, replace with Arc.
Committed by Anthony MOI
parent 0f81997351
commit 11e86a16c5
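The bullet points above summarize the change: the Python-facing wrapper now keeps a shareable, reference-counted handle to the processor instead of a Container that could alias memory owned elsewhere. Below is a minimal, self-contained sketch of the two shapes, using hypothetical stand-in types (Proc, Bert, and a simplified Container are illustrations, not the real tokenizers definitions):

use std::sync::Arc;

// Hypothetical stand-in for the tk PostProcessor trait.
trait Proc {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

// Simplified sketch of the old ownership model: the wrapper either owned its
// processor or held an unchecked pointer into another object's allocation,
// with no lifetime tracking -- the "unsound" part mentioned above.
#[allow(dead_code)]
enum Container<T: ?Sized> {
    Owned(Box<T>),
    Pointer(*mut T),
}

#[allow(dead_code)]
struct OldPostProcessor {
    processor: Container<dyn Proc>,
}

// The new shape: a plain Arc field makes the wrapper Clone, and a clone is
// just a refcount bump on the shared processor.
#[derive(Clone)]
struct PyPostProcessor {
    processor: Arc<dyn Proc>,
}

struct Bert;
impl Proc for Bert {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

fn main() {
    let _old = OldPostProcessor { processor: Container::Owned(Box::new(Bert)) };

    let new = PyPostProcessor { processor: Arc::new(Bert) };
    let copy = new.clone(); // cheap: shared ownership, no deep copy, no raw pointer
    assert!(Arc::ptr_eq(&new.processor, &copy.processor));
    assert_eq!(Arc::strong_count(&new.processor), 2);
    assert_eq!(copy.processor.added_tokens(true), 3);
}

Because the new wrapper derives Clone over an Arc, handing it to another owner is a refcount bump rather than a move, which is what lets the diff below drop the Container plumbing.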
@@ -82,10 +82,10 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
 /// Processors Module
 #[pymodule]
 fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<processors::PostProcessor>()?;
-    m.add_class::<processors::BertProcessing>()?;
-    m.add_class::<processors::RobertaProcessing>()?;
-    m.add_class::<processors::ByteLevel>()?;
+    m.add_class::<processors::PyPostProcessor>()?;
+    m.add_class::<processors::PyBertProcessing>()?;
+    m.add_class::<processors::PyRobertaProcessing>()?;
+    m.add_class::<processors::PyByteLevel>()?;
     Ok(())
 }

@@ -1,40 +1,84 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;

-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use tk::processors::bert::BertProcessing;
+use tk::processors::byte_level::ByteLevel;
+use tk::processors::roberta::RobertaProcessing;
+use tk::{Encoding, PostProcessor};
+use tokenizers as tk;

-#[pyclass(dict, module = "tokenizers.processors")]
-pub struct PostProcessor {
-    pub processor: Container<dyn tk::tokenizer::PostProcessor>,
+#[pyclass(dict, module = "tokenizers.processors", name=PostProcessor)]
+#[derive(Clone)]
+pub struct PyPostProcessor {
+    pub processor: Arc<dyn PostProcessor>,
+}
+
+impl PyPostProcessor {
+    pub fn new(processor: Arc<dyn PostProcessor>) -> Self {
+        PyPostProcessor { processor }
+    }
+}
+
+#[typetag::serde]
+impl PostProcessor for PyPostProcessor {
+    fn added_tokens(&self, is_pair: bool) -> usize {
+        self.processor.added_tokens(is_pair)
+    }
+
+    fn process(
+        &self,
+        encoding: Encoding,
+        pair_encoding: Option<Encoding>,
+        add_special_tokens: bool,
+    ) -> tk::Result<Encoding> {
+        self.processor
+            .process(encoding, pair_encoding, add_special_tokens)
+    }
+}
+
+impl Serialize for PyPostProcessor {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.processor.serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for PyPostProcessor {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyPostProcessor::new(Arc::deserialize(deserializer)?))
+    }
 }

 #[pymethods]
-impl PostProcessor {
+impl PyPostProcessor {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .processor
-            .execute(|processor| serde_json::to_string(&processor))
-            .map_err(|e| {
-                exceptions::Exception::py_err(format!(
-                    "Error while attempting to pickle PostProcessor: {}",
-                    e.to_string()
-                ))
-            })?;
+        let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle PostProcessor: {}",
+                e.to_string()
+            ))
+        })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
     }

     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.processor =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                        exceptions::Exception::py_err(format!(
-                            "Error while attempting to unpickle PostProcessor: {}",
-                            e.to_string()
-                        ))
-                    })?);
+                self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle PostProcessor: {}",
+                        e.to_string()
+                    ))
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
@@ -42,23 +86,19 @@ impl PostProcessor {
         }
     }

     fn num_special_tokens_to_add(&self, is_pair: bool) -> usize {
-        self.processor.execute(|p| p.added_tokens(is_pair))
+        self.processor.added_tokens(is_pair)
     }
 }

-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct BertProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=BertProcessing)]
+pub struct PyBertProcessing {}
 #[pymethods]
-impl BertProcessing {
+impl PyBertProcessing {
     #[new]
-    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PostProcessor)> {
+    fn new(sep: (String, u32), cls: (String, u32)) -> PyResult<(Self, PyPostProcessor)> {
         Ok((
-            BertProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(tk::processors::bert::BertProcessing::new(
-                    sep, cls,
-                ))),
-            },
+            PyBertProcessing {},
+            PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls))),
         ))
     }

@@ -67,10 +107,10 @@ impl BertProcessing {
     }
 }

-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct RobertaProcessing {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=RobertaProcessing)]
+pub struct PyRobertaProcessing {}
 #[pymethods]
-impl RobertaProcessing {
+impl PyRobertaProcessing {
     #[new]
     #[args(trim_offsets = true, add_prefix_space = true)]
     fn new(
@@ -78,17 +118,11 @@ impl RobertaProcessing {
         cls: (String, u32),
         trim_offsets: bool,
         add_prefix_space: bool,
-    ) -> PyResult<(Self, PostProcessor)> {
-        Ok((
-            RobertaProcessing {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(
-                    tk::processors::roberta::RobertaProcessing::new(sep, cls)
-                        .trim_offsets(trim_offsets)
-                        .add_prefix_space(add_prefix_space),
-                )),
-            },
-        ))
+    ) -> PyResult<(Self, PyPostProcessor)> {
+        let proc = RobertaProcessing::new(sep, cls)
+            .trim_offsets(trim_offsets)
+            .add_prefix_space(add_prefix_space);
+        Ok((PyRobertaProcessing {}, PyPostProcessor::new(Arc::new(proc))))
     }

     fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
@@ -96,14 +130,14 @@ impl RobertaProcessing {
     }
 }

-#[pyclass(extends=PostProcessor, module = "tokenizers.processors")]
-pub struct ByteLevel {}
+#[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name=ByteLevel)]
+pub struct PyByteLevel {}
 #[pymethods]
-impl ByteLevel {
+impl PyByteLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PostProcessor)> {
-        let mut byte_level = tk::processors::byte_level::ByteLevel::default();
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPostProcessor)> {
+        let mut byte_level = ByteLevel::default();

         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
@@ -114,11 +148,6 @@ impl ByteLevel {
                 }
             }
         }
-        Ok((
-            ByteLevel {},
-            PostProcessor {
-                processor: Container::Owned(Box::new(byte_level)),
-            },
-        ))
+        Ok((PyByteLevel {}, PyPostProcessor::new(Arc::new(byte_level))))
     }
 }
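The __getstate__/__setstate__ methods above pickle the processor by serializing the trait object behind the Arc with serde_json. Here is a minimal sketch of that round trip with a hypothetical #[typetag::serde] trait (Proc and Bert are stand-ins, not the tokenizers types); the typetag registry is what lets a bare trait object come back as the right concrete type, which the Arc::deserialize call in the diff presumably relies on as well:

use std::sync::Arc;
use serde::{Deserialize, Serialize};

// Hypothetical trait mirroring the pattern above; typetag records each
// implementor so trait objects can be deserialized back to concrete types.
#[typetag::serde]
trait Proc {
    fn added_tokens(&self, is_pair: bool) -> usize;
}

#[derive(Serialize, Deserialize)]
struct Bert {
    sep_id: u32,
    cls_id: u32,
}

#[typetag::serde]
impl Proc for Bert {
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair { 3 } else { 2 }
    }
}

fn main() -> serde_json::Result<()> {
    let processor: Arc<dyn Proc> = Arc::new(Bert { sep_id: 102, cls_id: 101 });

    // __getstate__: serialize the trait object behind the Arc.
    let json = serde_json::to_string(processor.as_ref())?;

    // __setstate__: deserialize to a boxed trait object, then move it into an Arc.
    let restored: Arc<dyn Proc> = serde_json::from_str::<Box<dyn Proc>>(&json)?.into();
    assert_eq!(restored.added_tokens(true), 3);
    Ok(())
}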
@@ -7,7 +7,7 @@ use pyo3::types::*;
 use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
-    PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
+    PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, Tokenizer, TruncationParams,
     TruncationStrategy,
 };
 use tokenizers as tk;
@@ -18,9 +18,9 @@ use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
 use super::normalizers::PyNormalizer;
 use super::pre_tokenizers::PyPreTokenizer;
-use super::processors::PostProcessor;
 use super::trainers::PyTrainer;
 use super::utils::Container;
+use crate::processors::PyPostProcessor;

 #[pyclass(dict, module = "tokenizers", name=AddedToken)]
 pub struct PyAddedToken {
@@ -268,9 +268,9 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }

-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor>;

-#[pyclass(dict, module = "tokenizers")]
+#[pyclass(dict, module = "tokenizers", name=Tokenizer)]
 pub struct PyTokenizer {
     tokenizer: TokenizerImpl,
 }
@@ -362,7 +362,7 @@ impl PyTokenizer {
         Ok(self
             .tokenizer
             .get_post_processor()
-            .map_or(0, |p| p.as_ref().added_tokens(is_pair)))
+            .map_or(0, |p| p.added_tokens(is_pair)))
     }

     #[args(with_added_tokens = true)]
@@ -727,25 +727,13 @@ impl PyTokenizer {
     }

     #[getter]
-    fn get_post_processor(&self) -> PyResult<Option<PostProcessor>> {
-        Ok(self
-            .tokenizer
-            .get_post_processor()
-            .map(|processor| PostProcessor {
-                processor: Container::from_ref(processor),
-            }))
+    fn get_post_processor(&self) -> Option<PyPostProcessor> {
+        self.tokenizer.get_post_processor().cloned()
     }

     #[setter]
-    fn set_post_processor(&mut self, mut processor: PyRefMut<PostProcessor>) -> PyResult<()> {
-        if let Some(processor) = processor.processor.to_pointer() {
-            self.tokenizer.with_post_processor(processor);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Processor is already being used in another Tokenizer",
-            ))
-        }
+    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
+        self.tokenizer.with_post_processor(processor.clone());
     }

     #[getter]
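The last hunk changes get_post_processor to return a cloned wrapper and set_post_processor to take a shared PyRef instead of a PyRefMut. A small hypothetical sketch of why that works (none of these types are the real tokenizers API): since the wrapper clones by bumping the Arc, one processor can back several tokenizers at once, whereas the old Container-based setter had to fail with "The Processor is already being used in another Tokenizer".

use std::sync::Arc;

// Stand-in wrapper: cloning it only clones the inner Arc.
#[derive(Clone)]
struct PyPostProc {
    processor: Arc<str>, // stand-in for Arc<dyn PostProcessor>
}

#[derive(Default)]
struct Tok {
    post_processor: Option<PyPostProc>,
}

impl Tok {
    // Mirrors the new getter: hand back a cheap clone instead of a borrowed Container.
    fn get_post_processor(&self) -> Option<PyPostProc> {
        self.post_processor.clone()
    }
    // Mirrors the new setter: a shared reference is enough.
    fn set_post_processor(&mut self, p: &PyPostProc) {
        self.post_processor = Some(p.clone());
    }
}

fn main() {
    let shared = PyPostProc { processor: Arc::from("byte-level") };
    let (mut a, mut b) = (Tok::default(), Tok::default());
    a.set_post_processor(&shared);
    b.set_post_processor(&shared); // no error: shared ownership via Arc
    assert!(Arc::ptr_eq(
        &a.get_post_processor().unwrap().processor,
        &b.get_post_processor().unwrap().processor
    ));
}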