Python - PyNormalizer & PyPreTokenizer use a RwLock

This commit is contained in:
Anthony MOI
2020-11-16 11:37:28 -05:00
committed by Anthony MOI
parent 76d3b2128b
commit c22cfc31f9
3 changed files with 66 additions and 58 deletions

View File

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::prelude::*;
@ -36,7 +36,7 @@ impl PyNormalizer {
let py = gil.python();
Ok(match self.normalizer {
PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
PyNormalizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyNormalizerWrapper::Wrapped(ref inner) => match inner {
NormalizerWrapper::Sequence(_) => {
@ -421,8 +421,8 @@ impl Serialize for PyNormalizerWrapper {
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyNormalizerTypeWrapper {
Sequence(Vec<Arc<PyNormalizerWrapper>>),
Single(Arc<PyNormalizerWrapper>),
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
Single(Arc<RwLock<PyNormalizerWrapper>>),
}
impl Serialize for PyNormalizerTypeWrapper {
@ -456,7 +456,7 @@ where
I: Into<PyNormalizerWrapper>,
{
fn from(norm: I) -> Self {
PyNormalizerTypeWrapper::Single(Arc::new(norm.into()))
PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
}
}
@ -474,10 +474,11 @@ where
impl Normalizer for PyNormalizerTypeWrapper {
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
match self {
PyNormalizerTypeWrapper::Single(inner) => inner.normalize(normalized),
PyNormalizerTypeWrapper::Sequence(inner) => {
inner.iter().map(|n| n.normalize(normalized)).collect()
}
PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
PyNormalizerTypeWrapper::Sequence(inner) => inner
.iter()
.map(|n| n.read().unwrap().normalize(normalized))
.collect(),
}
}
}
@ -520,7 +521,7 @@ mod test {
assert_eq!(py_ser, rs_ser);
let py_norm: PyNormalizer = serde_json::from_str(&rs_ser).unwrap();
match py_norm.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
_ => panic!("Expected NFKC"),
},
@ -547,7 +548,7 @@ mod test {
let string = r#"{"type": "NFKC"}"#;
let normalizer: PyNormalizer = serde_json::from_str(&string).unwrap();
match normalizer.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
_ => panic!("Expected NFKC"),
},
@ -558,7 +559,7 @@ mod test {
let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap();
match normalizer.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
let normalizers = sequence.get_normalizers();
assert_eq!(normalizers.len(), 1);
@ -570,6 +571,6 @@ mod test {
_ => panic!("Expected sequence"),
},
_ => panic!("Expected single"),
}
};
}
}