mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Ensure serialization works in all expected ways.
This commit is contained in:
committed by
Anthony MOI
parent
aaf8e932b1
commit
16f75d9efc
@ -4,51 +4,57 @@ use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use tk::normalizers::bert::BertNormalizer;
|
||||
use tk::normalizers::strip::Strip;
|
||||
use tk::normalizers::unicode::{NFC, NFD, NFKC, NFKD};
|
||||
use tk::normalizers::utils::{Lowercase, Sequence};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use tk::normalizers::{BertNormalizer, Lowercase, NormalizerWrapper, Strip, NFC, NFD, NFKC, NFKD};
|
||||
use tk::{NormalizedString, Normalizer};
|
||||
use tokenizers as tk;
|
||||
|
||||
#[pyclass(dict, module = "tokenizers.normalizers", name=Normalizer)]
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct PyNormalizer {
|
||||
pub normalizer: Arc<dyn Normalizer>,
|
||||
#[serde(flatten)]
|
||||
pub(crate) normalizer: PyNormalizerWrapper,
|
||||
}
|
||||
|
||||
impl PyNormalizer {
|
||||
pub fn new(normalizer: Arc<dyn Normalizer>) -> Self {
|
||||
pub(crate) fn new(normalizer: PyNormalizerWrapper) -> Self {
|
||||
PyNormalizer { normalizer }
|
||||
}
|
||||
pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
|
||||
let base = self.clone();
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
match self.normalizer {
|
||||
PyNormalizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base)).map(Into::into),
|
||||
PyNormalizerWrapper::Wrapped(ref inner) => match inner.as_ref() {
|
||||
NormalizerWrapper::Sequence(_) => {
|
||||
Py::new(py, (PySequence {}, base)).map(Into::into)
|
||||
}
|
||||
NormalizerWrapper::BertNormalizer(_) => {
|
||||
Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
|
||||
}
|
||||
NormalizerWrapper::StripNormalizer(_) => {
|
||||
Py::new(py, (PyBertNormalizer {}, base)).map(Into::into)
|
||||
}
|
||||
NormalizerWrapper::NFC(_) => Py::new(py, (PyNFC {}, base)).map(Into::into),
|
||||
NormalizerWrapper::NFD(_) => Py::new(py, (PyNFD {}, base)).map(Into::into),
|
||||
NormalizerWrapper::NFKC(_) => Py::new(py, (PyNFKC {}, base)).map(Into::into),
|
||||
NormalizerWrapper::NFKD(_) => Py::new(py, (PyNFKD {}, base)).map(Into::into),
|
||||
NormalizerWrapper::Lowercase(_) => {
|
||||
Py::new(py, (PyLowercase {}, base)).map(Into::into)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl Normalizer for PyNormalizer {
|
||||
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
|
||||
self.normalizer.normalize(normalized)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for PyNormalizer {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
self.normalizer.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for PyNormalizer {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
Ok(PyNormalizer::new(Arc::deserialize(deserializer)?))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyNormalizer {
|
||||
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
|
||||
@ -103,7 +109,7 @@ impl PyBertNormalizer {
|
||||
}
|
||||
let normalizer =
|
||||
BertNormalizer::new(clean_text, handle_chinese_chars, strip_accents, lowercase);
|
||||
Ok((PyBertNormalizer {}, PyNormalizer::new(Arc::new(normalizer))))
|
||||
Ok((PyBertNormalizer {}, normalizer.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,7 +119,7 @@ pub struct PyNFD {}
|
||||
impl PyNFD {
|
||||
#[new]
|
||||
fn new() -> PyResult<(Self, PyNormalizer)> {
|
||||
Ok((PyNFD {}, PyNormalizer::new(Arc::new(NFD))))
|
||||
Ok((PyNFD {}, PyNormalizer::new(NFD.into())))
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,7 +129,7 @@ pub struct PyNFKD {}
|
||||
impl PyNFKD {
|
||||
#[new]
|
||||
fn new() -> PyResult<(Self, PyNormalizer)> {
|
||||
Ok((PyNFKD {}, PyNormalizer::new(Arc::new(NFKD))))
|
||||
Ok((PyNFKD {}, NFKD.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,7 +139,7 @@ pub struct PyNFC {}
|
||||
impl PyNFC {
|
||||
#[new]
|
||||
fn new() -> PyResult<(Self, PyNormalizer)> {
|
||||
Ok((PyNFC {}, PyNormalizer::new(Arc::new(NFC))))
|
||||
Ok((PyNFC {}, NFC.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -143,7 +149,7 @@ pub struct PyNFKC {}
|
||||
impl PyNFKC {
|
||||
#[new]
|
||||
fn new() -> PyResult<(Self, PyNormalizer)> {
|
||||
Ok((PyNFKC {}, PyNormalizer::new(Arc::new(NFKC))))
|
||||
Ok((PyNFKC {}, NFKC.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -153,19 +159,19 @@ pub struct PySequence {}
|
||||
impl PySequence {
|
||||
#[new]
|
||||
fn new(normalizers: &PyList) -> PyResult<(Self, PyNormalizer)> {
|
||||
let normalizers = normalizers
|
||||
.iter()
|
||||
.map(|n| {
|
||||
let normalizer: PyRef<PyNormalizer> = n.extract()?;
|
||||
let normalizer = PyNormalizer::new(normalizer.normalizer.clone());
|
||||
let boxed = Box::new(normalizer);
|
||||
Ok(boxed as Box<dyn Normalizer>)
|
||||
})
|
||||
.collect::<PyResult<_>>()?;
|
||||
|
||||
let mut sequence = Vec::with_capacity(normalizers.len());
|
||||
for n in normalizers.iter() {
|
||||
let normalizer: PyRef<PyNormalizer> = n.extract()?;
|
||||
match &normalizer.normalizer {
|
||||
PyNormalizerWrapper::Sequence(inner) => {
|
||||
sequence.extend(inner.iter().map(|i| i.clone()))
|
||||
}
|
||||
PyNormalizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
PySequence {},
|
||||
PyNormalizer::new(Arc::new(Sequence::new(normalizers))),
|
||||
PyNormalizer::new(PyNormalizerWrapper::Sequence(sequence)),
|
||||
))
|
||||
}
|
||||
|
||||
@ -180,7 +186,7 @@ pub struct PyLowercase {}
|
||||
impl PyLowercase {
|
||||
#[new]
|
||||
fn new() -> PyResult<(Self, PyNormalizer)> {
|
||||
Ok((PyLowercase {}, PyNormalizer::new(Arc::new(Lowercase))))
|
||||
Ok((PyLowercase {}, Lowercase.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,9 +209,114 @@ impl PyStrip {
|
||||
}
|
||||
}
|
||||
|
||||
Ok((
|
||||
PyStrip {},
|
||||
PyNormalizer::new(Arc::new(Strip::new(left, right))),
|
||||
))
|
||||
Ok((PyStrip {}, Strip::new(left, right).into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub(crate) enum PyNormalizerWrapper {
|
||||
Sequence(Vec<Arc<NormalizerWrapper>>),
|
||||
Wrapped(Arc<NormalizerWrapper>),
|
||||
}
|
||||
|
||||
impl Serialize for PyNormalizerWrapper {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match self {
|
||||
PyNormalizerWrapper::Sequence(seq) => {
|
||||
let mut ser = serializer.serialize_struct("Sequence", 2)?;
|
||||
ser.serialize_field("type", "Sequence")?;
|
||||
ser.serialize_field("normalizers", seq)?;
|
||||
ser.end()
|
||||
}
|
||||
PyNormalizerWrapper::Wrapped(inner) => inner.serialize(serializer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> From<I> for PyNormalizerWrapper
|
||||
where
|
||||
I: Into<NormalizerWrapper>,
|
||||
{
|
||||
fn from(norm: I) -> Self {
|
||||
PyNormalizerWrapper::Wrapped(Arc::new(norm.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<I> From<I> for PyNormalizer
|
||||
where
|
||||
I: Into<NormalizerWrapper>,
|
||||
{
|
||||
fn from(norm: I) -> Self {
|
||||
PyNormalizer {
|
||||
normalizer: norm.into().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Normalizer for PyNormalizerWrapper {
|
||||
fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> {
|
||||
match self {
|
||||
PyNormalizerWrapper::Wrapped(inner) => inner.normalize(normalized),
|
||||
PyNormalizerWrapper::Sequence(inner) => {
|
||||
inner.iter().map(|n| n.normalize(normalized)).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use pyo3::{AsPyRef, Python};
    use tk::normalizers::unicode::{NFC, NFKC};
    use tk::normalizers::utils::Sequence;
    use tk::normalizers::NormalizerWrapper;

    use crate::normalizers::{PyNormalizer, PyNormalizerWrapper};

    /// A wrapped NFC normalizer must be exposed to Python as the `NFC` class.
    #[test]
    fn get_subtype() {
        let wrapped_nfc = PyNormalizer::new(NFC.into());
        let as_object = wrapped_nfc.get_as_subtype().unwrap();
        let gil = Python::acquire_gil();
        assert_eq!(
            "tokenizers.normalizers.NFC",
            as_object.as_ref(gil.python()).get_type().name()
        );
    }

    /// The Python-side wrappers must serialize exactly like the Rust-side
    /// types, and Rust-side output must deserialize into a `PyNormalizer`.
    #[test]
    fn serialize() {
        // Single normalizer: Python wrapper and Rust wrapper agree.
        let py_single: PyNormalizerWrapper = NFKC.into();
        let py_single_json = serde_json::to_string(&py_single).unwrap();
        let rs_single = NormalizerWrapper::NFKC(NFKC);
        let rs_single_json = serde_json::to_string(&rs_single).unwrap();
        assert_eq!(py_single_json, rs_single_json);

        // Round trip: Rust-side JSON comes back as a wrapped NFKC.
        let restored: PyNormalizer = serde_json::from_str(&rs_single_json).unwrap();
        match restored.normalizer {
            PyNormalizerWrapper::Wrapped(inner) => match inner.as_ref() {
                NormalizerWrapper::NFKC(_) => {}
                _ => panic!("Expected NFKC"),
            },
            _ => panic!("Expected wrapped, not sequence."),
        }

        // Sequence: Python wrapper matches the Rust `NormalizerWrapper`.
        let py_sequence: PyNormalizerWrapper =
            Sequence::new(vec![NFC.into(), NFKC.into()]).into();
        let py_wrapper_json = serde_json::to_string(&py_sequence).unwrap();
        let rs_wrapped_sequence =
            NormalizerWrapper::Sequence(Sequence::new(vec![NFC.into(), NFKC.into()]).into());
        let rs_wrapped_json = serde_json::to_string(&rs_wrapped_sequence).unwrap();
        assert_eq!(py_wrapper_json, rs_wrapped_json);

        // Wrapping in `PyNormalizer` must not change the output.
        let py_normalizer = PyNormalizer::new(py_sequence);
        let py_normalizer_json = serde_json::to_string(&py_normalizer).unwrap();
        assert_eq!(py_wrapper_json, py_normalizer_json);

        // A bare Rust `Sequence` serializes to the same JSON as well.
        let rs_sequence = Sequence::new(vec![NFC.into(), NFKC.into()]);
        let rs_sequence_json = serde_json::to_string(&rs_sequence).unwrap();
        assert_eq!(py_wrapper_json, rs_sequence_json);
    }
}
|
||||
|
Reference in New Issue
Block a user