Python - PyNormalizer & PyPreTokenizer use a RwLock
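
The patch swaps the plain `Arc<PyNormalizerWrapper>` / `Arc<PyPreTokenizerWrapper>` stored in the two type-wrapper enums for `Arc<RwLock<...>>`. `Arc` alone only hands out shared `&T` access, so wrapping the inner value in a `RwLock` adds interior mutability with many-readers/one-writer semantics, presumably so the Python side can mutate a normalizer or pre-tokenizer that a tokenizer already holds a shared handle to. A minimal sketch of the pattern (illustration only, not code from this patch):

    use std::sync::{Arc, RwLock};

    fn main() {
        // Two handles to one shared value.
        let shared = Arc::new(RwLock::new(String::from("NFKC")));
        let handle = Arc::clone(&shared);

        // A write lock mutates the value in place, behind the shared Arc...
        handle.write().unwrap().push_str("+Lowercase");

        // ...and read locks (which may be held concurrently) observe the change.
        assert_eq!(&*shared.read().unwrap(), "NFKC+Lowercase");
    }

The normalizer half of the patch:
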
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -36,7 +36,7 @@ impl PyNormalizer {
         let py = gil.python();
         Ok(match self.normalizer {
             PyNormalizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
-            PyNormalizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
+            PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() {
                 PyNormalizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
                 PyNormalizerWrapper::Wrapped(ref inner) => match inner {
                     NormalizerWrapper::Sequence(_) => {
@@ -421,8 +421,8 @@ impl Serialize for PyNormalizerWrapper {
 #[derive(Clone, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyNormalizerTypeWrapper {
-    Sequence(Vec<Arc<PyNormalizerWrapper>>),
-    Single(Arc<PyNormalizerWrapper>),
+    Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),
+    Single(Arc<RwLock<PyNormalizerWrapper>>),
 }
 
 impl Serialize for PyNormalizerTypeWrapper {
@@ -456,7 +456,7 @@ where
     I: Into<PyNormalizerWrapper>,
 {
     fn from(norm: I) -> Self {
-        PyNormalizerTypeWrapper::Single(Arc::new(norm.into()))
+        PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into())))
     }
 }
 
@@ -474,10 +474,11 @@ where
 impl Normalizer for PyNormalizerTypeWrapper {
     fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
         match self {
-            PyNormalizerTypeWrapper::Single(inner) => inner.normalize(normalized),
-            PyNormalizerTypeWrapper::Sequence(inner) => {
-                inner.iter().map(|n| n.normalize(normalized)).collect()
-            }
+            PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized),
+            PyNormalizerTypeWrapper::Sequence(inner) => inner
+                .iter()
+                .map(|n| n.read().unwrap().normalize(normalized))
+                .collect(),
         }
     }
 }
@@ -520,7 +521,7 @@ mod test {
         assert_eq!(py_ser, rs_ser);
         let py_norm: PyNormalizer = serde_json::from_str(&rs_ser).unwrap();
         match py_norm.normalizer {
-            PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
+            PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
                 PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
                 _ => panic!("Expected NFKC"),
             },
@@ -547,7 +548,7 @@ mod test {
         let string = r#"{"type": "NFKC"}"#;
         let normalizer: PyNormalizer = serde_json::from_str(&string).unwrap();
         match normalizer.normalizer {
-            PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
+            PyNormalizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
                 PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {}
                 _ => panic!("Expected NFKC"),
             },
@@ -558,7 +559,7 @@ mod test {
         let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap();
 
         match normalizer.normalizer {
-            PyNormalizerTypeWrapper::Single(inner) => match inner.as_ref() {
+            PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() {
                 PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => {
                     let normalizers = sequence.get_normalizers();
                     assert_eq!(normalizers.len(), 1);
@@ -570,6 +571,6 @@ mod test {
                 _ => panic!("Expected sequence"),
             },
             _ => panic!("Expected single"),
-        }
+        };
     }
 }
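
One detail in the hunks above: `read()` returns a `RwLockReadGuard`, not the wrapper itself, so the match sites dereference the guard (`*inner.as_ref().read().unwrap()`), and where an arm needs to borrow out of the protected value (the `sequence` binding) they re-borrow it with `&*`. A self-contained sketch of that guard re-borrow, with made-up stand-in types:

    use std::sync::RwLock;

    // Illustration-only stand-in for the wrapper enum.
    enum Wrapper {
        Named(String),
        Anonymous,
    }

    fn name_len(lock: &RwLock<Wrapper>) -> usize {
        // `read()` yields a guard; `&*guard` re-borrows the protected value so
        // the `Named(s)` arm can bind `s: &String` instead of moving out of it.
        match &*lock.read().unwrap() {
            Wrapper::Named(s) => s.len(),
            Wrapper::Anonymous => 0,
        }
    }

    fn main() {
        let named = RwLock::new(Wrapper::Named("NFKC".into()));
        let anon = RwLock::new(Wrapper::Anonymous);
        assert_eq!(name_len(&named), 4);
        assert_eq!(name_len(&anon), 0);
    }

The pre-tokenizer half follows the same pattern:
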
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{Arc, RwLock};
 
 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -48,40 +48,46 @@ impl PyPreTokenizer {
             PyPreTokenizerTypeWrapper::Sequence(_) => {
                 Py::new(py, (PySequence {}, base))?.into_py(py)
             }
-            PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
-                PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
-                PyPreTokenizerWrapper::Wrapped(inner) => match inner {
-                    PreTokenizerWrapper::Whitespace(_) => {
-                        Py::new(py, (PyWhitespace {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::Split(_) => Py::new(py, (PySplit {}, base))?.into_py(py),
-                    PreTokenizerWrapper::Punctuation(_) => {
-                        Py::new(py, (PyPunctuation {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::Sequence(_) => {
-                        Py::new(py, (PySequence {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::Metaspace(_) => {
-                        Py::new(py, (PyMetaspace {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::Delimiter(_) => {
-                        Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::WhitespaceSplit(_) => {
-                        Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::ByteLevel(_) => {
-                        Py::new(py, (PyByteLevel {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::BertPreTokenizer(_) => {
-                        Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
-                    }
-                    PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
-                    PreTokenizerWrapper::UnicodeScripts(_) => {
-                        Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py)
-                    }
-                },
-            },
+            PyPreTokenizerTypeWrapper::Single(ref inner) => {
+                match &*inner.as_ref().read().unwrap() {
+                    PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
+                    PyPreTokenizerWrapper::Wrapped(inner) => match inner {
+                        PreTokenizerWrapper::Whitespace(_) => {
+                            Py::new(py, (PyWhitespace {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::Split(_) => {
+                            Py::new(py, (PySplit {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::Punctuation(_) => {
+                            Py::new(py, (PyPunctuation {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::Sequence(_) => {
+                            Py::new(py, (PySequence {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::Metaspace(_) => {
+                            Py::new(py, (PyMetaspace {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::Delimiter(_) => {
+                            Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::WhitespaceSplit(_) => {
+                            Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::ByteLevel(_) => {
+                            Py::new(py, (PyByteLevel {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::BertPreTokenizer(_) => {
+                            Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::Digits(_) => {
+                            Py::new(py, (PyDigits {}, base))?.into_py(py)
+                        }
+                        PreTokenizerWrapper::UnicodeScripts(_) => {
+                            Py::new(py, (PyUnicodeScripts {}, base))?.into_py(py)
+                        }
+                    },
+                }
+            }
         })
     }
 }
@@ -492,8 +498,8 @@ impl Serialize for PyPreTokenizerWrapper {
 #[derive(Clone, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerTypeWrapper {
-    Sequence(Vec<Arc<PyPreTokenizerWrapper>>),
-    Single(Arc<PyPreTokenizerWrapper>),
+    Sequence(Vec<Arc<RwLock<PyPreTokenizerWrapper>>>),
+    Single(Arc<RwLock<PyPreTokenizerWrapper>>),
 }
 
 impl Serialize for PyPreTokenizerTypeWrapper {
@@ -527,7 +533,7 @@ where
     I: Into<PyPreTokenizerWrapper>,
 {
     fn from(pretok: I) -> Self {
-        PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into()))
+        PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into())))
     }
 }
 
@@ -545,10 +551,11 @@ where
 impl PreTokenizer for PyPreTokenizerTypeWrapper {
     fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
         match self {
-            PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok),
-            PyPreTokenizerTypeWrapper::Sequence(inner) => {
-                inner.iter().map(|n| n.pre_tokenize(pretok)).collect()
-            }
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok),
+            PyPreTokenizerTypeWrapper::Sequence(inner) => inner
+                .iter()
+                .map(|n| n.read().unwrap().pre_tokenize(pretok))
+                .collect(),
        }
     }
 }
@@ -593,7 +600,7 @@ mod test {
         assert_eq!(py_ser, rs_ser);
         let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
         match py_pretok.pretok {
-            PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() {
+            PyPreTokenizerTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() {
                 PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
                 _ => panic!("Expected Whitespace"),
             },
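
The rewritten `Sequence` arms lean on a standard-library idiom: an iterator of `Result<(), E>` can be collected into a single `Result<(), E>`, which runs every element in order and short-circuits on the first error. A minimal sketch (the `Step` alias is made up for illustration):

    type Step = fn(&mut String) -> Result<(), String>;

    // Runs every step in order, stopping at the first Err, exactly like the
    // `inner.iter().map(..).collect()` in the Sequence arms above.
    fn run_all(steps: &[Step], s: &mut String) -> Result<(), String> {
        steps.iter().map(|f| f(s)).collect()
    }

    fn main() {
        let ok: Step = |s| {
            s.push('a');
            Ok(())
        };
        let fail: Step = |_| Err("boom".to_string());

        let mut buf = String::new();
        assert!(run_all(&[ok, ok], &mut buf).is_ok());
        assert_eq!(buf, "aa");
        // The failing step short-circuits; the trailing `ok` never runs.
        assert!(run_all(&[ok, fail, ok], &mut buf).is_err());
        assert_eq!(buf, "aaa");
    }

Finally, the tokenizer test adjusts the one call site that builds a `Sequence` by hand:
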
@@ -1181,8 +1181,8 @@ mod test {
     fn serialize() {
         let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
         tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
-            Arc::new(NFKC.into()),
-            Arc::new(Lowercase.into()),
+            Arc::new(RwLock::new(NFKC.into())),
+            Arc::new(RwLock::new(Lowercase.into())),
         ])));
 
         let tmp = NamedTempFile::new().unwrap().into_temp_path();
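
As the test above shows, code that builds a `Sequence` by hand now wraps every element in `Arc::new(RwLock::new(..))`, while single wrappers still go through the `From` impl, which hides the locking from callers. A simplified sketch of that conversion pattern (stand-in names, not the real bindings types):

    use std::sync::{Arc, RwLock};

    // Illustration-only stand-ins for PyNormalizerWrapper / PyNormalizerTypeWrapper.
    struct Wrapper(String);

    #[allow(dead_code)]
    enum TypeWrapper {
        Single(Arc<RwLock<Wrapper>>),
        Sequence(Vec<Arc<RwLock<Wrapper>>>),
    }

    impl From<&str> for Wrapper {
        fn from(s: &str) -> Self {
            Wrapper(s.to_string())
        }
    }

    // Mirrors the patched From impl: callers never spell out the Arc/RwLock.
    impl<I> From<I> for TypeWrapper
    where
        I: Into<Wrapper>,
    {
        fn from(w: I) -> Self {
            TypeWrapper::Single(Arc::new(RwLock::new(w.into())))
        }
    }

    fn main() {
        let tw: TypeWrapper = "NFKC".into(); // goes through the blanket From impl
        if let TypeWrapper::Single(inner) = tw {
            assert_eq!(inner.read().unwrap().0, "NFKC");
        }
    }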