Python - Fix Normalizer.normalize with PyNormalizedStringRefMut

This commit is contained in:
Anthony MOI
2021-02-03 11:20:57 -05:00
committed by Anthony MOI
parent 355315e8d3
commit db22cb6315
3 changed files with 45 additions and 6 deletions

View File

@ -16,6 +16,29 @@ use tk::normalizers::{
use tk::{NormalizedString, Normalizer};
use tokenizers as tk;
/// Represents the different kinds of `NormalizedString` we can receive from Python:
/// - Owned: Created in Python and owned by Python
/// - RefMut: A mutable reference to a `NormalizedString` owned by Rust
///
/// The derived `FromPyObject` allows this enum to be used directly as the type of an
/// argument coming from Python.
#[derive(FromPyObject)]
enum PyNormalizedStringMut<'p> {
/// A `PyNormalizedString` created in Python, mutably borrowed for the lifetime `'p`.
Owned(PyRefMut<'p, PyNormalizedString>),
/// A handle giving mutable access to a `NormalizedString` that is owned by Rust.
RefMut(PyNormalizedStringRefMut),
}
impl PyNormalizedStringMut<'_> {
    /// Normalize the underlying `NormalizedString` in place with the provided normalizer.
    ///
    /// Both variants end up calling `normalizer.normalize` on the wrapped string:
    /// - `Owned` passes a mutable borrow of its `normalized` field.
    /// - `RefMut` goes through `map_as_mut`; if that call itself fails, its error is
    ///   returned as-is before the normalizer's result is looked at.
    ///
    /// # Errors
    ///
    /// Any error reported by the normalizer is converted into a Python `Exception`
    /// carrying the error's `Display` text.
    pub fn normalize_with<N>(&mut self, normalizer: &N) -> PyResult<()>
    where
        N: Normalizer,
    {
        let outcome = match self {
            Self::Owned(owned) => normalizer.normalize(&mut owned.normalized),
            Self::RefMut(shared) => shared.map_as_mut(|inner| normalizer.normalize(inner))?,
        };

        outcome.map_err(|err| exceptions::PyException::new_err(format!("{}", err)))
    }
}
/// Base class for all normalizers
///
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
@ -122,8 +145,8 @@ impl PyNormalizer {
/// The normalized string on which to apply this
/// :class:`~tokenizers.normalizers.Normalizer`
#[text_signature = "(self, normalized)"]
fn normalize(&self, normalized: &mut PyNormalizedString) -> PyResult<()> {
ToPyResult(self.normalizer.normalize(&mut normalized.normalized)).into()
fn normalize(&self, mut normalized: PyNormalizedStringMut) -> PyResult<()> {
normalized.normalize_with(&self.normalizer)
}
/// Normalize the given string
@ -457,7 +480,7 @@ impl PyReplace {
}
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) struct CustomNormalizer {
inner: PyObject,
}
@ -500,7 +523,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer {
}
}
#[derive(Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyNormalizerWrapper {
Custom(CustomNormalizer),
@ -519,7 +542,7 @@ impl Serialize for PyNormalizerWrapper {
}
}
#[derive(Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyNormalizerTypeWrapper {
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),