Python - PyNormalizer & PyPreTokenizer use a RwLock

This commit is contained in:
Anthony MOI
2020-11-16 11:37:28 -05:00
committed by Anthony MOI
parent 76d3b2128b
commit c22cfc31f9
3 changed files with 66 additions and 58 deletions

View File

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::prelude::*;
@ -36,7 +36,7 @@ impl PyNormalizer {
let py = gil.python();
Ok(match self.normalizer {
PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
PyNormalizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyNormalizerWrapper::Wrapped(ref inner) => match inner {
NormalizerWrapper::Sequence(_) => {
@ -421,8 +421,8 @@ impl Serialize for PyNormalizerWrapper {
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyNormalizerTypeWrapper {
Sequence(Vec<Arc<PyNormalizerWrapper>>),
Single(Arc<PyNormalizerWrapper>),
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
Single(Arc<RwLock<PyNormalizerWrapper>>),
}
impl Serialize for PyNormalizerTypeWrapper {
@ -456,7 +456,7 @@ where
I: Into<PyNormalizerWrapper>,
{
fn from(norm: I) -> Self {
PyNormalizerTypeWrapper::Single(Arc::new(norm.into()))
PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
}
}
@ -474,10 +474,11 @@ where
impl Normalizer for PyNormalizerTypeWrapper {
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
match self {
PyNormalizerTypeWrapper::Single(inner) => inner.normalize(normalized),
PyNormalizerTypeWrapper::Sequence(inner) => {
inner.iter().map(|n| n.normalize(normalized)).collect()
}
PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
PyNormalizerTypeWrapper::Sequence(inner) => inner
.iter()
.map(|n| n.read().unwrap().normalize(normalized))
.collect(),
}
}
}
@ -520,7 +521,7 @@ mod test {
assert_eq!(py_ser, rs_ser);
let py_norm: PyNormalizer = serde_json::from_str(&rs_ser).unwrap();
match py_norm.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
_ => panic!("Expected NFKC"),
},
@ -547,7 +548,7 @@ mod test {
let string = r#"{"type": "NFKC"}"#;
let normalizer: PyNormalizer = serde_json::from_str(&string).unwrap();
match normalizer.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
_ => panic!("Expected NFKC"),
},
@ -558,7 +559,7 @@ mod test {
let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap();
match normalizer.normalizer {
PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
let normalizers = sequence.get_normalizers();
assert_eq!(normalizers.len(), 1);
@ -570,6 +571,6 @@ mod test {
_ => panic!("Expected sequence"),
},
_ => panic!("Expected single"),
}
};
}
}

View File

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::prelude::*;
@ -48,40 +48,46 @@ impl PyPreTokenizer {
PyPreTokenizerTypeWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base))?.into_py(py)
}
PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyPreTokenizerWrapper::Wrapped(inner) => match inner {
PreTokenizerWrapper::Whitespace(_) => {
Py::new(py, (PyWhitespace {}, base))?.into_py(py)
}
PreTokenizerWrapper::Split(_) => Py::new(py, (PySplit {}, base))?.into_py(py),
PreTokenizerWrapper::Punctuation(_) => {
Py::new(py, (PyPunctuation {}, base))?.into_py(py)
}
PreTokenizerWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base))?.into_py(py)
}
PreTokenizerWrapper::Metaspace(_) => {
Py::new(py, (PyMetaspace {}, base))?.into_py(py)
}
PreTokenizerWrapper::Delimiter(_) => {
Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::WhitespaceSplit(_) => {
Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::ByteLevel(_) => {
Py::new(py, (PyByteLevel {}, base))?.into_py(py)
}
PreTokenizerWrapper::BertPreTokenizer(_) => {
Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
}
PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
PreTokenizerWrapper::UnicodeScripts(_) => {
Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py)
}
},
},
PyPreTokenizerTypeWrapper::Single(ref inner) => {
match &*inner.as_ref().read().unwrap() {
PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
PyPreTokenizerWrapper::Wrapped(inner) => match inner {
PreTokenizerWrapper::Whitespace(_) => {
Py::new(py, (PyWhitespace {}, base))?.into_py(py)
}
PreTokenizerWrapper::Split(_) => {
Py::new(py, (PySplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::Punctuation(_) => {
Py::new(py, (PyPunctuation {}, base))?.into_py(py)
}
PreTokenizerWrapper::Sequence(_) => {
Py::new(py, (PySequence {}, base))?.into_py(py)
}
PreTokenizerWrapper::Metaspace(_) => {
Py::new(py, (PyMetaspace {}, base))?.into_py(py)
}
PreTokenizerWrapper::Delimiter(_) => {
Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::WhitespaceSplit(_) => {
Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
}
PreTokenizerWrapper::ByteLevel(_) => {
Py::new(py, (PyByteLevel {}, base))?.into_py(py)
}
PreTokenizerWrapper::BertPreTokenizer(_) => {
Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
}
PreTokenizerWrapper::Digits(_) => {
Py::new(py, (PyDigits {}, base))?.into_py(py)
}
PreTokenizerWrapper::UnicodeScripts(_) => {
Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py)
}
},
}
}
})
}
}
@ -492,8 +498,8 @@ impl Serialize for PyPreTokenizerWrapper {
#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerTypeWrapper {
Sequence(Vec<Arc<PyPreTokenizerWrapper>>),
Single(Arc<PyPreTokenizerWrapper>),
Sequence(Vec<Arc<RwLock<PyPreTokenizerWrapper>>>),
Single(Arc<RwLock<PyPreTokenizerWrapper>>),
}
impl Serialize for PyPreTokenizerTypeWrapper {
@ -527,7 +533,7 @@ where
I: Into<PyPreTokenizerWrapper>,
{
fn from(pretok: I) -> Self {
PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into()))
PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
}
}
@ -545,10 +551,11 @@ where
impl PreTokenizer for PyPreTokenizerTypeWrapper {
fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
match self {
PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok),
PyPreTokenizerTypeWrapper::Sequence(inner) => {
inner.iter().map(|n| n.pre_tokenize(pretok)).collect()
}
PyPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
PyPreTokenizerTypeWrapper::Sequence(inner) => inner
.iter()
.map(|n| n.read().unwrap().pre_tokenize(pretok))
.collect(),
}
}
}
@ -593,7 +600,7 @@ mod test {
assert_eq!(py_ser, rs_ser);
let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
match py_pretok.pretok {
PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() {
PyPreTokenizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
_ => panic!("Expected Whitespace"),
},

View File

@ -1181,8 +1181,8 @@ mod test {
fn serialize() {
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
Arc::new(NFKC.into()),
Arc::new(Lowercase.into()),
Arc::new(RwLock::new(NFKC.into())),
Arc::new(RwLock::new(Lowercase.into())),
])));
let tmp = NamedTempFile::new().unwrap().into_temp_path();