Creating normalizers.Prepend (To be used instead of Metaspace). (#1194)

* Creating `normalizers.Prepend` (To be used instead of `Metaspace`).

* Linting + stub.

* Fixing pickling/unpickling by setting a default.

* Black.
This commit is contained in:
Nicolas Patry
2023-03-24 00:33:31 +01:00
committed by GitHub
parent 250d46c676
commit d2c8190a0f
10 changed files with 191 additions and 5 deletions

View File

@ -9,8 +9,8 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::normalizers::{
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Replace, Strip, StripAccents,
NFC, NFD, NFKC, NFKD,
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace, Strip,
StripAccents, NFC, NFD, NFKC, NFKD,
};
use tk::{NormalizedString, Normalizer};
use tokenizers as tk;
@ -69,6 +69,7 @@ impl PyNormalizer {
// Each wrapper variant must be surfaced to Python as its own subclass.
// A strip normalizer maps to `PyStrip` (the class this module exposes as
// `Strip`), not to `PyBertNormalizer`, otherwise Python would see the wrong
// type and the wrong set of properties for this normalizer.
NormalizerWrapper::StripNormalizer(_) => {
    Py::new(py, (PyStrip {}, base))?.into_py(py)
}
NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))?.into_py(py),
NormalizerWrapper::StripAccents(_) => {
Py::new(py, (PyStripAccents {}, base))?.into_py(py)
}
@ -172,7 +173,8 @@ macro_rules! getter {
let super_ = $self.as_ref();
if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer {
let wrapper = norm.read().unwrap();
if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper {
if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone()
{
o.$name
} else {
unreachable!()
@ -413,6 +415,29 @@ impl PyStrip {
}
}
/// Prepend normalizer
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Prepend")]
#[pyo3(text_signature = "(self, prepend)")]
pub struct PyPrepend {}

// NOTE: plain `//` comments are used inside this impl on purpose; `///` comments on
// `#[pymethods]` items become Python-visible docstrings.
#[pymethods]
impl PyPrepend {
    // Python constructor. `prepend` is optional on the Python side and falls back to
    // an empty string, so the class can be instantiated with no arguments.
    // The returned pair is (subclass marker, base `PyNormalizer` built from `Prepend`).
    #[new]
    #[pyo3(signature = (prepend="".to_string()))]
    fn new(prepend: String) -> (Self, PyNormalizer) {
        let inner = Prepend::new(prepend);
        (Self {}, inner.into())
    }

    // Property getter for `prepend`: delegates to the shared `getter!` macro, which
    // reads the `prepend` field of the wrapped `NormalizerWrapper::Prepend` value.
    #[getter]
    fn get_prepend(self_: PyRef<Self>) -> String {
        getter!(self_, Prepend, prepend)
    }

    // Property setter for `prepend`: delegates to the shared `setter!` macro, targeting
    // the `prepend` field of the `Prepend` variant with the value passed from Python.
    #[setter]
    fn set_prepend(self_: PyRef<Self>, prepend: String) {
        setter!(self_, Prepend, prepend, prepend)
    }
}
/// StripAccents normalizer
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "StripAccents")]
#[pyo3(text_signature = "(self)")]
@ -624,6 +649,7 @@ pub fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyLowercase>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyStripAccents>()?;
m.add_class::<PyPrepend>()?;
m.add_class::<PyNmt>()?;
m.add_class::<PyPrecompiled>()?;
m.add_class::<PyReplace>()?;