diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index de9ef34c..9d0b9fbf 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use pyo3::exceptions; use pyo3::prelude::*; @@ -36,7 +36,7 @@ impl PyNormalizer { let py = gil.python(); Ok(match self.normalizer { PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py), - PyNormalizerTypeWrapper::Single(ref inner) => match inner.as_ref() { + PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() { PyNormalizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py), PyNormalizerWrapper::Wrapped(ref inner) => match inner { NormalizerWrapper::Sequence(_) => { @@ -421,8 +421,8 @@ impl Serialize for PyNormalizerWrapper { #[derive(Clone, Deserialize)] #[serde(untagged)] pub(crate) enum PyNormalizerTypeWrapper { - Sequence(Vec>), - Single(Arc), + Sequence(Vec>>), + Single(Arc>), } impl Serialize for PyNormalizerTypeWrapper { @@ -456,7 +456,7 @@ where I: Into, { fn from(norm: I) -> Self { - PyNormalizerTypeWrapper::Single(Arc::new(norm.into())) + PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into()))) } } @@ -474,10 +474,11 @@ where impl Normalizer for PyNormalizerTypeWrapper { fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> { match self { - PyNormalizerTypeWrapper::Single(inner) => inner.normalize(normalized), - PyNormalizerTypeWrapper::Sequence(inner) => { - inner.iter().map(|n| n.normalize(normalized)).collect() - } + PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized), + PyNormalizerTypeWrapper::Sequence(inner) => inner + .iter() + .map(|n| n.read().unwrap().normalize(normalized)) + .collect(), } } } @@ -520,7 +521,7 @@ mod test { assert_eq!(py_ser, rs_ser); let py_norm: PyNormalizer = serde_json::from_str(&rs_ser).unwrap(); match py_norm.normalizer { - PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() { + PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() { PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {} _ => panic!("Expected NFKC"), }, @@ -547,7 +548,7 @@ mod test { let string = r#"{"type": "NFKC"}"#; let normalizer: PyNormalizer = serde_json::from_str(&string).unwrap(); match normalizer.normalizer { - PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() { + PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() { PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {} _ => panic!("Expected NFKC"), }, @@ -558,7 +559,7 @@ mod test { let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap(); match normalizer.normalizer { - PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() { + PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() { PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => { let normalizers = sequence.get_normalizers(); assert_eq!(normalizers.len(), 1); @@ -570,6 +571,6 @@ mod test { _ => panic!("Expected sequence"), }, _ => panic!("Expected single"), - } + }; } } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index c118871c..b067eaa4 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use pyo3::exceptions; use pyo3::prelude::*; @@ -48,40 +48,46 @@ impl PyPreTokenizer { PyPreTokenizerTypeWrapper::Sequence(_) => { Py::new(py, (PySequence {}, base))?.into_py(py) } - PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() { - PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py), - PyPreTokenizerWrapper::Wrapped(inner) => match inner { - PreTokenizerWrapper::Whitespace(_) => { - Py::new(py, (PyWhitespace {}, base))?.into_py(py) - } - PreTokenizerWrapper::Split(_) => Py::new(py, (PySplit {}, base))?.into_py(py), - PreTokenizerWrapper::Punctuation(_) => { - Py::new(py, (PyPunctuation {}, base))?.into_py(py) - } - PreTokenizerWrapper::Sequence(_) => { - Py::new(py, (PySequence {}, base))?.into_py(py) - } - PreTokenizerWrapper::Metaspace(_) => { - Py::new(py, (PyMetaspace {}, base))?.into_py(py) - } - PreTokenizerWrapper::Delimiter(_) => { - Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py) - } - PreTokenizerWrapper::WhitespaceSplit(_) => { - Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py) - } - PreTokenizerWrapper::ByteLevel(_) => { - Py::new(py, (PyByteLevel {}, base))?.into_py(py) - } - PreTokenizerWrapper::BertPreTokenizer(_) => { - Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py) - } - PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py), - PreTokenizerWrapper::UnicodeScripts(_) => { - Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py) - } - }, - }, + PyPreTokenizerTypeWrapper::Single(ref inner) => { + match &*inner.as_ref().read().unwrap() { + PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py), + PyPreTokenizerWrapper::Wrapped(inner) => match inner { + PreTokenizerWrapper::Whitespace(_) => { + Py::new(py, (PyWhitespace {}, base))?.into_py(py) + } + PreTokenizerWrapper::Split(_) => { + Py::new(py, (PySplit {}, base))?.into_py(py) + } + PreTokenizerWrapper::Punctuation(_) => { + Py::new(py, (PyPunctuation {}, base))?.into_py(py) + } + PreTokenizerWrapper::Sequence(_) => { + Py::new(py, (PySequence {}, base))?.into_py(py) + } + PreTokenizerWrapper::Metaspace(_) => { + Py::new(py, (PyMetaspace {}, base))?.into_py(py) + } + PreTokenizerWrapper::Delimiter(_) => { + Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py) + } + PreTokenizerWrapper::WhitespaceSplit(_) => { + Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py) + } + PreTokenizerWrapper::ByteLevel(_) => { + Py::new(py, (PyByteLevel {}, base))?.into_py(py) + } + PreTokenizerWrapper::BertPreTokenizer(_) => { + Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py) + } + PreTokenizerWrapper::Digits(_) => { + Py::new(py, (PyDigits {}, base))?.into_py(py) + } + PreTokenizerWrapper::UnicodeScripts(_) => { + Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py) + } + }, + } + } }) } } @@ -492,8 +498,8 @@ impl Serialize for PyPreTokenizerWrapper { #[derive(Clone, Deserialize)] #[serde(untagged)] pub(crate) enum PyPreTokenizerTypeWrapper { - Sequence(Vec>), - Single(Arc), + Sequence(Vec>>), + Single(Arc>), } impl Serialize for PyPreTokenizerTypeWrapper { @@ -527,7 +533,7 @@ where I: Into, { fn from(pretok: I) -> Self { - PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into())) + PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into()))) } } @@ -545,10 +551,11 @@ where impl PreTokenizer for PyPreTokenizerTypeWrapper { fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> { match self { - PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok), - PyPreTokenizerTypeWrapper::Sequence(inner) => { - inner.iter().map(|n| n.pre_tokenize(pretok)).collect() - } + PyPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok), + PyPreTokenizerTypeWrapper::Sequence(inner) => inner + .iter() + .map(|n| n.read().unwrap().pre_tokenize(pretok)) + .collect(), } } } @@ -593,7 +600,7 @@ mod test { assert_eq!(py_ser, rs_ser); let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap(); match py_pretok.pretok { - PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() { + PyPreTokenizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() { PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {} _ => panic!("Expected Whitespace"), }, diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 89237e17..067dad1b 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1181,8 +1181,8 @@ mod test { fn serialize() { let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default())); tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![ - Arc::new(NFKC.into()), - Arc::new(Lowercase.into()), + Arc::new(RwLock::new(NFKC.into())), + Arc::new(RwLock::new(Lowercase.into())), ]))); let tmp = NamedTempFile::new().unwrap().into_temp_path();