Python - Fix Normalizer.normalize with PyNormalizedStringRefMut

This commit is contained in:
Anthony MOI
2021-02-03 11:20:57 -05:00
committed by Anthony MOI
parent 355315e8d3
commit db22cb6315
3 changed files with 45 additions and 6 deletions

View File

@ -8,7 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- [#616]: Fix SentencePiece tokenizers conversion
- [#616]: Fix offsets produced by Precompiled Normalizer (used by tokenizers converted from SPM)
- [#617]: Fix offsets produced by Precompiled Normalizer (used by tokenizers converted from SPM)
- [#618]: Fix Normalizer.normalize with `PyNormalizedStringRefMut`
## [0.10.0]
@ -299,6 +300,7 @@ delimiter (Works like `.split(delimiter)`)
- Fix a bug that was causing crashes in Python 3.5
[#618]: https://github.com/huggingface/tokenizers/pull/618
[#617]: https://github.com/huggingface/tokenizers/pull/617
[#616]: https://github.com/huggingface/tokenizers/pull/616
[#590]: https://github.com/huggingface/tokenizers/pull/590

View File

@ -16,6 +16,29 @@ use tk::normalizers::{
use tk::{NormalizedString, Normalizer};
use tokenizers as tk;
/// Represents the different kind of NormalizedString we can receive from Python:
/// - Owned: Created in Python and owned by Python
/// - RefMut: A mutable reference to a NormalizedString owned by Rust
#[derive(FromPyObject)]
enum PyNormalizedStringMut<'p> {
Owned(PyRefMut<'p, PyNormalizedString>),
RefMut(PyNormalizedStringRefMut),
}
impl PyNormalizedStringMut<'_> {
/// Normalized the underlying `NormalizedString` using the provided normalizer
pub fn normalize_with<N>(&mut self, normalizer: &N) -> PyResult<()>
where
N: Normalizer,
{
match self {
PyNormalizedStringMut::Owned(ref mut n) => normalizer.normalize(&mut n.normalized),
PyNormalizedStringMut::RefMut(n) => n.map_as_mut(|n| normalizer.normalize(n))?,
}
.map_err(|e| exceptions::PyException::new_err(format!("{}", e)))
}
}
/// Base class for all normalizers
///
/// This class is not supposed to be instantiated directly. Instead, any implementation of a
@ -122,8 +145,8 @@ impl PyNormalizer {
/// The normalized string on which to apply this
/// :class:`~tokenizers.normalizers.Normalizer`
#[text_signature = "(self, normalized)"]
fn normalize(&self, normalized: &mut PyNormalizedString) -> PyResult<()> {
ToPyResult(self.normalizer.normalize(&mut normalized.normalized)).into()
fn normalize(&self, mut normalized: PyNormalizedStringMut) -> PyResult<()> {
normalized.normalize_with(&self.normalizer)
}
/// Normalize the given string
@ -457,7 +480,7 @@ impl PyReplace {
}
}
#[derive(Clone)]
#[derive(Debug, Clone)]
pub(crate) struct CustomNormalizer {
inner: PyObject,
}
@ -500,7 +523,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer {
}
}
#[derive(Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyNormalizerWrapper {
Custom(CustomNormalizer),
@ -519,7 +542,7 @@ impl Serialize for PyNormalizerWrapper {
}
}
#[derive(Clone, Deserialize)]
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyNormalizerTypeWrapper {
Sequence(Vec<Arc<RwLock<PyNormalizerWrapper>>>),

View File

@ -411,6 +411,20 @@ impl PyNormalizedStringRefMut {
pub fn destroyed_error() -> PyErr {
exceptions::PyException::new_err("Cannot use a NormalizedStringRefMut outside `normalize`")
}
/// Provides a way to access a reference to the underlying NormalizedString
pub fn map_as_ref<F: FnOnce(&NormalizedString) -> U, U>(&self, f: F) -> PyResult<U> {
self.inner
.map(f)
.ok_or_else(PyNormalizedStringRefMut::destroyed_error)
}
/// Provides a way to access a mutable reference to the underlying NormalizedString
pub fn map_as_mut<F: FnOnce(&mut NormalizedString) -> U, U>(&mut self, f: F) -> PyResult<U> {
self.inner
.map_mut(f)
.ok_or_else(PyNormalizedStringRefMut::destroyed_error)
}
}
#[pymethods]