mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Python - PyNormalizer & PyPreTokenizer use a RwLock
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
@ -36,7 +36,7 @@ impl PyNormalizer {
|
||||
let py = gil.python();
|
||||
Ok(match self.normalizer {
|
||||
PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
|
||||
PyNormalizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
|
||||
PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() {
|
||||
PyNormalizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
|
||||
PyNormalizerWrapper::Wrapped(ref inner) => match inner {
|
||||
NormalizerWrapper::Sequence(_) => {
|
||||
@ -421,8 +421,8 @@ impl Serialize for PyNormalizerWrapper {
|
||||
#[derive(Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub(crate) enum PyNormalizerTypeWrapper {
|
||||
Sequence(Vec<Arc<PyNormalizerWrapper>>),
|
||||
Single(Arc<PyNormalizerWrapper>),
|
||||
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
|
||||
Single(Arc<RwLock<PyNormalizerWrapper>>),
|
||||
}
|
||||
|
||||
impl Serialize for PyNormalizerTypeWrapper {
|
||||
@ -456,7 +456,7 @@ where
|
||||
I: Into<PyNormalizerWrapper>,
|
||||
{
|
||||
fn from(norm: I) -> Self {
|
||||
PyNormalizerTypeWrapper::Single(Arc::new(norm.into()))
|
||||
PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
|
||||
}
|
||||
}
|
||||
|
||||
@ -474,10 +474,11 @@ where
|
||||
impl Normalizer for PyNormalizerTypeWrapper {
|
||||
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
|
||||
match self {
|
||||
PyNormalizerTypeWrapper::Single(inner) => inner.normalize(normalized),
|
||||
PyNormalizerTypeWrapper::Sequence(inner) => {
|
||||
inner.iter().map(|n| n.normalize(normalized)).collect()
|
||||
}
|
||||
PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
|
||||
PyNormalizerTypeWrapper::Sequence(inner) => inner
|
||||
.iter()
|
||||
.map(|n| n.read().unwrap().normalize(normalized))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -520,7 +521,7 @@ mod test {
|
||||
assert_eq!(py_ser, rs_ser);
|
||||
let py_norm: PyNormalizer = serde_json::from_str(&rs_ser).unwrap();
|
||||
match py_norm.normalizer {
|
||||
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
|
||||
PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
|
||||
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
|
||||
_ => panic!("Expected NFKC"),
|
||||
},
|
||||
@ -547,7 +548,7 @@ mod test {
|
||||
let string = r#"{"type": "NFKC"}"#;
|
||||
let normalizer: PyNormalizer = serde_json::from_str(&string).unwrap();
|
||||
match normalizer.normalizer {
|
||||
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
|
||||
PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
|
||||
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
|
||||
_ => panic!("Expected NFKC"),
|
||||
},
|
||||
@ -558,7 +559,7 @@ mod test {
|
||||
let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap();
|
||||
|
||||
match normalizer.normalizer {
|
||||
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
|
||||
PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
|
||||
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
|
||||
let normalizers = sequence.get_normalizers();
|
||||
assert_eq!(normalizers.len(), 1);
|
||||
@ -570,6 +571,6 @@ mod test {
|
||||
_ => panic!("Expected sequence"),
|
||||
},
|
||||
_ => panic!("Expected single"),
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user