Remove Container from Normalizers, replace with Arc.

* prefix the Python types in Rust with Py
* remove unsound Container wrappers, replace with Arc
Author: Sebastian Pütz
Date: 2020-07-25 15:44:28 +02:00
Committed by: Anthony MOI
Parent: 83a52c8080
Commit: 08b8c48127
3 changed files with 120 additions and 143 deletions
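The heart of the change, visible in the diffs below, is that the Python-facing wrapper now holds an `Arc<dyn Normalizer>` instead of the old `Container`: cloning the wrapper is just a reference-count bump, so the same normalizer can be held by the Python object and by a Tokenizer at the same time. A minimal, self-contained sketch of that pattern, using a simplified stand-in trait rather than the real `tk::Normalizer` API:

use std::sync::Arc;

// Simplified stand-in for the real `tk::Normalizer` trait, which operates on NormalizedString.
trait Normalizer {
    fn normalize(&self, s: &mut String);
}

struct Lowercase;
impl Normalizer for Lowercase {
    fn normalize(&self, s: &mut String) {
        *s = s.to_lowercase();
    }
}

// The Python-facing wrapper: cloning it only bumps the Arc refcount, so the same
// underlying normalizer can be shared by the Python object and a Tokenizer at once.
#[derive(Clone)]
struct PyNormalizer {
    normalizer: Arc<dyn Normalizer>,
}

impl Normalizer for PyNormalizer {
    fn normalize(&self, s: &mut String) {
        self.normalizer.normalize(s)
    }
}

fn main() {
    let py_norm = PyNormalizer { normalizer: Arc::new(Lowercase) };
    let shared = py_norm.clone(); // cheap clone; both values point at the same Lowercase
    let mut s = String::from("HeLLo");
    shared.normalize(&mut s);
    assert_eq!(s, "hello");
}

With the previous `Container` wrapper, handing a normalizer to a Tokenizer moved it out of the Python object, which is why the bindings had to raise "already being used in another Tokenizer" errors; with `Arc` that failure mode disappears.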

@@ -92,15 +92,15 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
 /// Normalizers Module
 #[pymodule]
 fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<normalizers::Normalizer>()?;
-    m.add_class::<normalizers::BertNormalizer>()?;
-    m.add_class::<normalizers::NFD>()?;
-    m.add_class::<normalizers::NFKD>()?;
-    m.add_class::<normalizers::NFC>()?;
-    m.add_class::<normalizers::NFKC>()?;
-    m.add_class::<normalizers::Sequence>()?;
-    m.add_class::<normalizers::Lowercase>()?;
-    m.add_class::<normalizers::Strip>()?;
+    m.add_class::<normalizers::PyNormalizer>()?;
+    m.add_class::<normalizers::PyBertNormalizer>()?;
+    m.add_class::<normalizers::PyNFD>()?;
+    m.add_class::<normalizers::PyNFKD>()?;
+    m.add_class::<normalizers::PyNFC>()?;
+    m.add_class::<normalizers::PyNFKC>()?;
+    m.add_class::<normalizers::PySequence>()?;
+    m.add_class::<normalizers::PyLowercase>()?;
+    m.add_class::<normalizers::PyStrip>()?;
     Ok(())
 }

@@ -1,40 +1,75 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;
-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
-#[pyclass(dict, module = "tokenizers.normalizers")]
-pub struct Normalizer {
-    pub normalizer: Container<dyn tk::tokenizer::Normalizer>,
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use tk::normalizers::bert::BertNormalizer;
+use tk::normalizers::strip::Strip;
+use tk::normalizers::unicode::{NFC, NFD, NFKC, NFKD};
+use tk::normalizers::utils::{Lowercase, Sequence};
+use tk::{NormalizedString, Normalizer};
+use tokenizers as tk;
+#[pyclass(dict, module = "tokenizers.normalizers", name=Normalizer)]
+#[derive(Clone)]
+pub struct PyNormalizer {
+    pub normalizer: Arc<dyn Normalizer>,
 }
+impl PyNormalizer {
+    pub fn new(normalizer: Arc<dyn Normalizer>) -> Self {
+        PyNormalizer { normalizer }
+    }
+}
+#[typetag::serde]
+impl Normalizer for PyNormalizer {
+    fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
+        self.normalizer.normalize(normalized)
+    }
+}
+impl Serialize for PyNormalizer {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.normalizer.serialize(serializer)
+    }
+}
+impl<'de> Deserialize<'de> for PyNormalizer {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyNormalizer::new(Arc::deserialize(deserializer)?))
+    }
+}
 #[pymethods]
-impl Normalizer {
+impl PyNormalizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .normalizer
-            .execute(|normalizer| serde_json::to_string(&normalizer))
-            .map_err(|e| {
-                exceptions::Exception::py_err(format!(
-                    "Error while attempting to pickle Normalizer: {}",
-                    e.to_string()
-                ))
-            })?;
+        let data = serde_json::to_string(&self.normalizer).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle Normalizer: {}",
+                e.to_string()
+            ))
+        })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
     }
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.normalizer =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                        exceptions::Exception::py_err(format!(
-                            "Error while attempting to unpickle Normalizer: {}",
-                            e.to_string()
-                        ))
-                    })?);
+                self.normalizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle Normalizer: {}",
+                        e.to_string()
+                    ))
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
@@ -42,13 +77,13 @@ impl Normalizer {
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct BertNormalizer {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=BertNormalizer)]
+pub struct PyBertNormalizer {}
 #[pymethods]
-impl BertNormalizer {
+impl PyBertNormalizer {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyNormalizer)> {
         let mut clean_text = true;
         let mut handle_chinese_chars = true;
         let mut strip_accents = None;
@@ -66,108 +101,71 @@ impl BertNormalizer {
                 }
             }
         }
-        Ok((
-            BertNormalizer {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
-                    clean_text,
-                    handle_chinese_chars,
-                    strip_accents,
-                    lowercase,
-                ))),
-            },
-        ))
+        let normalizer =
+            BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase);
+        Ok((PyBertNormalizer {}, PyNormalizer::new(Arc::new(normalizer))))
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFD {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFD)]
+pub struct PyNFD {}
 #[pymethods]
-impl NFD {
+impl PyNFD {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFD {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFD)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFD {}, PyNormalizer::new(Arc::new(NFD))))
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFKD {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFKD)]
+pub struct PyNFKD {}
 #[pymethods]
-impl NFKD {
+impl PyNFKD {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFKD {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKD)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFKD {}, PyNormalizer::new(Arc::new(NFKD))))
    }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFC {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFC)]
+pub struct PyNFC {}
 #[pymethods]
-impl NFC {
+impl PyNFC {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFC {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFC)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFC {}, PyNormalizer::new(Arc::new(NFC))))
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct NFKC {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=NFKC)]
pub struct PyNFKC {}
 #[pymethods]
-impl NFKC {
+impl PyNFKC {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            NFKC {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::unicode::NFKC)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNFKC {}, PyNormalizer::new(Arc::new(NFKC))))
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct Sequence {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Sequence)]
+pub struct PySequence {}
 #[pymethods]
-impl Sequence {
+impl PySequence {
     #[new]
-    fn new(normalizers: &PyList) -> PyResult<(Self, Normalizer)> {
+    fn new(normalizers: &PyList) -> PyResult<(Self, PyNormalizer)> {
         let normalizers = normalizers
             .iter()
             .map(|n| {
-                let mut normalizer: PyRefMut<Normalizer> = n.extract()?;
-                if let Some(normalizer) = normalizer.normalizer.to_pointer() {
-                    Ok(normalizer)
-                } else {
-                    Err(exceptions::Exception::py_err(
-                        "At least one normalizer is already being used in another Tokenizer",
-                    ))
-                }
+                let normalizer: PyRef<PyNormalizer> = n.extract()?;
+                let normalizer = PyNormalizer::new(normalizer.normalizer.clone());
+                let boxed = Box::new(normalizer);
+                Ok(boxed as Box<dyn Normalizer>)
             })
             .collect::<PyResult<_>>()?;
         Ok((
-            Sequence {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::utils::Sequence::new(
-                    normalizers,
-                ))),
-            },
+            PySequence {},
+            PyNormalizer::new(Arc::new(Sequence::new(normalizers))),
         ))
     }
@@ -176,28 +174,23 @@ impl Sequence {
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct Lowercase {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Lowercase)]
+pub struct PyLowercase {}
 #[pymethods]
-impl Lowercase {
+impl PyLowercase {
     #[new]
-    fn new() -> PyResult<(Self, Normalizer)> {
-        Ok((
-            Lowercase {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::utils::Lowercase)),
-            },
-        ))
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyLowercase {}, PyNormalizer::new(Arc::new(Lowercase))))
     }
 }
-#[pyclass(extends=Normalizer, module = "tokenizers.normalizers")]
-pub struct Strip {}
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Strip)]
+pub struct PyStrip {}
 #[pymethods]
-impl Strip {
+impl PyStrip {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Normalizer)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyNormalizer)> {
         let mut left = true;
         let mut right = true;
@@ -211,12 +204,8 @@ impl Strip {
         }
         Ok((
-            Strip {},
-            Normalizer {
-                normalizer: Container::Owned(Box::new(tk::normalizers::strip::Strip::new(
-                    left, right,
-                ))),
-            },
+            PyStrip {},
+            PyNormalizer::new(Arc::new(Strip::new(left, right))),
         ))
     }
 }
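The manual Serialize/Deserialize impls above are what make `__getstate__`/`__setstate__` so short now: pickling simply serializes the wrapper to a JSON string and rebuilds it from bytes, with typetag tagging the concrete normalizer type. A rough, self-contained sketch of that round trip, using a plain serde enum as a stand-in for the typetag-backed `Arc<dyn Normalizer>` (the enum and its variants here are illustrative only, not the real serialized format):

use serde::{Deserialize, Serialize};

// Illustrative stand-in for the serialized normalizer state.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "type")]
enum NormalizerState {
    Lowercase,
    Strip { left: bool, right: bool },
}

fn main() -> Result<(), serde_json::Error> {
    let state = NormalizerState::Strip { left: true, right: false };
    // __getstate__: serialize to a JSON string whose bytes are handed to Python as PyBytes.
    let data = serde_json::to_string(&state)?;
    // __setstate__: rebuild the normalizer from those bytes.
    let restored: NormalizerState = serde_json::from_slice(data.as_bytes())?;
    assert_eq!(state, restored);
    Ok(())
}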

@@ -16,7 +16,7 @@ use super::decoders::Decoder;
 use super::encoding::PyEncoding;
 use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
-use super::normalizers::Normalizer;
+use super::normalizers::PyNormalizer;
 use super::pre_tokenizers::PreTokenizer;
 use super::processors::PostProcessor;
 use super::trainers::PyTrainer;
@@ -268,7 +268,7 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }
-type TokenizerImpl = Tokenizer<PyModel>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer>;
 #[pyclass(dict, module = "tokenizers")]
 pub struct PyTokenizer {
@@ -707,25 +707,13 @@ impl PyTokenizer {
     }
     #[getter]
-    fn get_normalizer(&self) -> PyResult<Option<Normalizer>> {
-        Ok(self
-            .tokenizer
-            .get_normalizer()
-            .map(|normalizer| Normalizer {
-                normalizer: Container::from_ref(normalizer),
-            }))
+    fn get_normalizer(&self) -> Option<PyNormalizer> {
+        self.tokenizer.get_normalizer().cloned()
     }
     #[setter]
-    fn set_normalizer(&mut self, mut normalizer: PyRefMut<Normalizer>) -> PyResult<()> {
-        if let Some(normalizer) = normalizer.normalizer.to_pointer() {
-            self.tokenizer.with_normalizer(normalizer);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Normalizer is already being used in another Tokenizer",
-            ))
-        }
+    fn set_normalizer(&mut self, normalizer: PyRef<PyNormalizer>) {
+        self.tokenizer.with_normalizer(normalizer.clone());
     }
     #[getter]
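With the Arc-backed wrapper, the tokenizer's normalizer getter and setter above reduce to plain clones, and the old "already being used in another Tokenizer" error path disappears. A toy sketch of that behaviour, using stand-in types rather than the real Tokenizer/Normalizer API:

use std::sync::Arc;

// Stand-in wrapper; the real one holds Arc<dyn Normalizer>.
#[derive(Clone)]
struct PyNormalizer {
    inner: Arc<String>,
}

// Toy tokenizer with the same getter/setter shape as the bindings use.
#[derive(Default)]
struct Tokenizer {
    normalizer: Option<PyNormalizer>,
}

impl Tokenizer {
    fn with_normalizer(&mut self, n: PyNormalizer) {
        self.normalizer = Some(n);
    }
    fn get_normalizer(&self) -> Option<&PyNormalizer> {
        self.normalizer.as_ref()
    }
}

fn main() {
    let mut tok = Tokenizer::default();
    let norm = PyNormalizer { inner: Arc::new("lowercase".to_string()) };
    tok.with_normalizer(norm.clone()); // setter: store a clone, the Python object keeps its copy
    let got = tok.get_normalizer().cloned(); // getter: Option<&PyNormalizer> -> Option<PyNormalizer>
    assert!(Arc::ptr_eq(&norm.inner, &got.unwrap().inner)); // both share the same underlying data
}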