Remove Container from Normalizers, replace with Arc.

* prefix the Python types in Rust with Py
* remove unsound Container wrappers, replace with Arc
Authored by Sebastian Pütz on 2020-07-25 15:44:28 +02:00
Committed by Anthony MOI
parent 83a52c8080
commit 08b8c48127
3 changed files with 120 additions and 143 deletions
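In short: the Python-facing wrapper used to hold a `Container<dyn Normalizer>`, which juggled an owned `Box` and a borrowed pointer (`Container::Owned`, `from_ref`, `to_pointer`) and could fail at runtime with "already being used in another Tokenizer". It now holds an `Arc<dyn Normalizer>`, so the wrapper derives `Clone` and sharing is sound. Below is a minimal sketch of the new ownership model, using a simplified stand-in trait rather than the real tokenizers definitions:

```rust
use std::sync::Arc;

// Simplified stand-in for the real tokenizers trait.
trait Normalizer: Send + Sync {
    fn normalize(&self, text: &mut String);
}

struct Lowercase;

impl Normalizer for Lowercase {
    fn normalize(&self, text: &mut String) {
        *text = text.to_lowercase();
    }
}

// The pattern this commit introduces: shared, reference-counted ownership
// of the trait object instead of the old Container wrapper.
#[derive(Clone)]
struct PyNormalizer {
    normalizer: Arc<dyn Normalizer>,
}

fn main() {
    let norm = PyNormalizer {
        normalizer: Arc::new(Lowercase),
    };
    // Cloning bumps a reference count; both handles share one normalizer.
    let shared = norm.clone();
    let mut text = String::from("HeLLo");
    shared.normalizer.normalize(&mut text);
    assert_eq!(text, "hello");
    assert_eq!(Arc::strong_count(&norm.normalizer), 2);
}
```

Cloning an `Arc` never copies the normalizer itself; it only increments a reference count, which is what lets the same instance safely back several Python objects at once.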

View File

@@ -92,15 +92,15 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
 /// Normalizers Module
 #[pymodule]
 fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<normalizers::Normalizer>()?;
-    m.add_class::<normalizers::BertNormalizer>()?;
-    m.add_class::<normalizers::NFD>()?;
-    m.add_class::<normalizers::NFKD>()?;
-    m.add_class::<normalizers::NFC>()?;
-    m.add_class::<normalizers::NFKC>()?;
-    m.add_class::<normalizers::Sequence>()?;
-    m.add_class::<normalizers::Lowercase>()?;
-    m.add_class::<normalizers::Strip>()?;
+    m.add_class::<normalizers::PyNormalizer>()?;
+    m.add_class::<normalizers::PyBertNormalizer>()?;
+    m.add_class::<normalizers::PyNFD>()?;
+    m.add_class::<normalizers::PyNFKD>()?;
+    m.add_class::<normalizers::PyNFC>()?;
+    m.add_class::<normalizers::PyNFKC>()?;
+    m.add_class::<normalizers::PySequence>()?;
+    m.add_class::<normalizers::PyLowercase>()?;
+    m.add_class::<normalizers::PyStrip>()?;
     Ok(())
 }

View File

@@ -1,22 +1,58 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;
 
-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use tk::normalizers::bert::BertNormalizer;
+use tk::normalizers::strip::Strip;
+use tk::normalizers::unicode::{NFC, NFD, NFKC, NFKD};
+use tk::normalizers::utils::{Lowercase, Sequence};
+use tk::{NormalizedString, Normalizer};
+use tokenizers as tk;
 
-#[pyclass(dict, module = "tokenizers.normalizers")]
-pub struct Normalizer {
-    pub normalizer: Container<dyn tk::tokenizer::Normalizer>,
+#[pyclass(dict, module = "tokenizers.normalizers", name=Normalizer)]
+#[derive(Clone)]
+pub struct PyNormalizer {
+    pub normalizer: Arc<dyn Normalizer>,
+}
+
+impl PyNormalizer {
+    pub fn new(normalizer: Arc<dyn Normalizer>) -> Self {
+        PyNormalizer { normalizer }
+    }
+}
+
+#[typetag::serde]
+impl Normalizer for PyNormalizer {
+    fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
+        self.normalizer.normalize(normalized)
+    }
+}
+
+impl Serialize for PyNormalizer {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.normalizer.serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for PyNormalizer {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyNormalizer::new(Arc::deserialize(deserializer)?))
+    }
 }
 
 #[pymethods]
-impl Normalizer {
+impl PyNormalizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .normalizer
-            .execute(|normalizer| serde_json::to_string(&normalizer))
-            .map_err(|e| {
+        let data = serde_json::to_string(&self.normalizer).map_err(|e| {
             exceptions::Exception::py_err(format!(
                 "Error while attempting to pickle Normalizer: {}",
                 e.to_string()
@@ -28,13 +64,12 @@ impl Normalizer {
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.normalizer =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
                     exceptions::Exception::py_err(format!(
                         "Error while attempting to unpickle Normalizer: {}",
                         e.to_string()
                     ))
-                    })?);
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
@@ -42,13 +77,13 @@ impl Normalizer {
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct BertNormalizer {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=BertNormalizer)]
+pub struct PyBertNormalizer {}
 #[pymethods]
-impl BertNormalizer {
+impl PyBertNormalizer {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyNormalizer)> {
         let mut clean_text = true;
         let mut handle_chinese_chars = true;
         let mut strip_accents = None;
@@ -66,108 +101,71 @@ impl BertNormalizer {
                 }
             }
         }
-        Ok((
-            BertNormalizer {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
-                    clean_text,
-                    handle_chinese_chars,
-                    strip_accents,
-                    lowercase,
-                ))),
-            },
-        ))
+        let normalizer =
+            BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase);
+        Ok((PyBertNormalizer {}, PyNormalizer::new(Arc::new(normalizer))))
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFD {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFD)]
+pub struct PyNFD {}
 #[pymethods]
-impl NFD {
+impl PyNFD {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFD {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFD {}, PyNormalizer::new(Arc::new(NFD))))
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFKD {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFKD)]
+pub struct PyNFKD {}
 #[pymethods]
-impl NFKD {
+impl PyNFKD {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFKD {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFKD {}, PyNormalizer::new(Arc::new(NFKD))))
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFC {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFC)]
pub struct PyNFC {}
 #[pymethods]
-impl NFC {
+impl PyNFC {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFC {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFC {}, PyNormalizer::new(Arc::new(NFC))))
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFKC {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFKC)]
+pub struct PyNFKC {}
 #[pymethods]
-impl NFKC {
+impl PyNFKC {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFKC {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFKC {}, PyNormalizer::new(Arc::new(NFKC))))
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct Sequence {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Sequence)]
+pub struct PySequence {}
 #[pymethods]
-impl Sequence {
+impl PySequence {
     #[new]
-    fn new(normalizers: &PyList) -> PyResult<(Self, Normalizer)> {
+    fn new(normalizers: &PyList) -> PyResult<(Self, PyNormalizer)> {
         let normalizers = normalizers
             .iter()
             .map(|n| {
-                let mut normalizer: PyRefMut<Normalizer> = n.extract()?;
-                if let Some(normalizer) = normalizer.normalizer.to_pointer() {
-                    Ok(normalizer)
-                } else {
-                    Err(exceptions::Exception::py_err(
-                        "At least one normalizer is already being used in another Tokenizer",
-                    ))
-                }
+                let normalizer: PyRef<PyNormalizer> = n.extract()?;
+                let normalizer = PyNormalizer::new(normalizer.normalizer.clone());
+                let boxed = Box::new(normalizer);
+                Ok(boxed as Box<dyn Normalizer>)
             })
             .collect::<PyResult<_>>()?;
         Ok((
-            Sequence {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new(
-                    normalizers,
-                ))),
-            },
+            PySequence {},
+            PyNormalizer::new(Arc::new(Sequence::new(normalizers))),
         ))
     }
@@ -176,28 +174,23 @@ impl Sequence {
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct Lowercase {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Lowercase)]
+pub struct PyLowercase {}
 #[pymethods]
-impl Lowercase {
+impl PyLowercase {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            Lowercase {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyLowercase {}, PyNormalizer::new(Arc::new(Lowercase))))
     }
 }
 
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct Strip {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Strip)]
+pub struct PyStrip {}
 #[pymethods]
-impl Strip {
+impl PyStrip {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyNormalizer)> {
         let mut left = true;
         let mut right = true;
@@ -211,12 +204,8 @@ impl Strip {
         }
 
         Ok((
-            Strip {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::strip::Strip::new(
-                    left, right,
-                ))),
-            },
+            PyStrip {},
+            PyNormalizer::new(Arc::new(Strip::new(left, right))),
         ))
     }
 }
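The `__getstate__`/`__setstate__` changes above now serialize the `Arc<dyn Normalizer>` directly with serde_json; the `#[typetag::serde]` attribute is what lets a bare trait object round-trip, because it stores a tag identifying the concrete type alongside its data. Here is a self-contained sketch of that mechanism with an illustrative trait (assumed dependencies: serde with `derive`, serde_json, typetag; the real code additionally goes through `Arc`, which needs serde's `rc` feature):

```rust
use serde::{Deserialize, Serialize};

// Illustrative trait, not the real tokenizers Normalizer.
#[typetag::serde]
trait Normalizer {
    fn name(&self) -> &'static str;
}

#[derive(Serialize, Deserialize)]
struct Lowercase;

#[typetag::serde]
impl Normalizer for Lowercase {
    fn name(&self) -> &'static str {
        "Lowercase"
    }
}

fn main() -> serde_json::Result<()> {
    let n: Box<dyn Normalizer> = Box::new(Lowercase);
    // typetag records the concrete type's tag next to its fields...
    let json = serde_json::to_string(&n)?;
    // ...so deserialization can rebuild the right concrete type
    // behind the pointer, exactly what unpickling needs.
    let back: Box<dyn Normalizer> = serde_json::from_str(&json)?;
    assert_eq!(back.name(), "Lowercase");
    Ok(())
}
```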

View File

@@ -16,7 +16,7 @@ use super::decoders::Decoder;
 use super::encoding::PyEncoding;
 use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
-use super::normalizers::Normalizer;
+use super::normalizers::PyNormalizer;
 use super::pre_tokenizers::PreTokenizer;
 use super::processors::PostProcessor;
 use super::trainers::PyTrainer;
@@ -268,7 +268,7 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }
 
-type TokenizerImpl = Tokenizer<PyModel>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer>;
 
 #[pyclass(dict, module = "tokenizers")]
 pub struct PyTokenizer {
@@ -707,25 +707,13 @@ impl PyTokenizer {
     }
 
     #[getter]
-    fn get_normalizer(&self) -> PyResult<Option<Normalizer>> {
-        Ok(self
-            .tokenizer
-            .get_normalizer()
-            .map(|normalizer| Normalizer {
-                normalizer: Container::from_ref(normalizer),
-            }))
+    fn get_normalizer(&self) -> Option<PyNormalizer> {
+        self.tokenizer.get_normalizer().cloned()
     }
 
     #[setter]
-    fn set_normalizer(&mut self, mut normalizer: PyRefMut<Normalizer>) -> PyResult<()> {
-        if let Some(normalizer) = normalizer.normalizer.to_pointer() {
-            self.tokenizer.with_normalizer(normalizer);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Normalizer is already being used in another Tokenizer",
-            ))
-        }
+    fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
+        self.tokenizer.with_normalizer(normalizer.clone());
     }
 
     #[getter]
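With the wrapper cheap to clone, the normalizer getter and setter in the hunk above shrink to one clone each, and the fallible `to_pointer` path (with its "already being used in another Tokenizer" error) disappears. A toy sketch of that shape; the types and the `Arc<str>` payload are stand-ins, not the real bindings:

```rust
use std::sync::Arc;

#[derive(Clone)]
struct PyNormalizer {
    normalizer: Arc<str>, // stand-in payload instead of Arc<dyn Normalizer>
}

struct Tokenizer {
    normalizer: Option<PyNormalizer>,
}

impl Tokenizer {
    // Same shape as the new getter: hand back an owned clone.
    fn get_normalizer(&self) -> Option<PyNormalizer> {
        self.normalizer.clone()
    }

    // Same shape as the new setter: accept the normalizer by cheap clone.
    fn with_normalizer(&mut self, normalizer: PyNormalizer) {
        self.normalizer = Some(normalizer);
    }
}

fn main() {
    let mut tok = Tokenizer { normalizer: None };
    tok.with_normalizer(PyNormalizer {
        normalizer: Arc::from("lowercase"),
    });
    let a = tok.get_normalizer().unwrap();
    let b = tok.get_normalizer().unwrap();
    // Every handle points at the same underlying normalizer.
    assert!(Arc::ptr_eq(&a.normalizer, &b.normalizer));
}
```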