From c45aebd1029acfbe9e5dfe64e8b8441d9fae727a Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:58:35 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Support=20updating=20template=20?= =?UTF-8?q?processors=20(#1652)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * current updates * simplify * set_item works, but `tokenizer._tokenizer.post_processor[1].single = ["$0", ""]` does not ! * fix: `normalizers` deserialization and other refactoring * fix: `pre_tokenizer` deserialization * feat: add `__len__` implementation for `normalizer::PySequence` * feat: add `__setitem__` impl for `normalizers::PySequence` * feat: add `__setitem__` impl to `pre_tokenizer::PySequence` * feat: add `__setitem__` impl to `post_processor::PySequence` * test: add normalizer sequence setter check * refactor: allow unused `processors::setter` macro * test: add `__setitem__` test for processors & pretok * refactor: `unwrap` -> `PyException::new_err()?` * refactor: fmt * refactor: remove unnecessary `pub` * feat(bindings): add missing getters & setters for pretoks * feat(bindings): add missing getters & setters for processors * refactor(bindings): rewrite RwLock poison error msg * refactor: remove debug print * feat(bindings): add description as to why custom deser is needed * feat: make post proc sequence elements mutable * fix(binding): serialization --------- Co-authored-by: Luc Georges --- .github/workflows/python.yml | 1 - bindings/python/Cargo.toml | 2 +- bindings/python/pyproject.toml | 35 +- bindings/python/src/normalizers.rs | 146 ++++-- bindings/python/src/pre_tokenizers.rs | 190 ++++++- bindings/python/src/processors.rs | 485 +++++++++++++++--- bindings/python/src/utils/normalization.rs | 6 + .../python/tests/bindings/test_normalizers.py | 66 ++- .../tests/bindings/test_pre_tokenizers.py | 69 ++- .../python/tests/bindings/test_processors.py | 15 + tokenizers/src/normalizers/mod.rs | 6 +- tokenizers/src/normalizers/replace.rs | 2 +- tokenizers/src/normalizers/utils.rs | 17 +- tokenizers/src/pre_tokenizers/metaspace.rs | 6 + tokenizers/src/pre_tokenizers/punctuation.rs | 2 +- tokenizers/src/pre_tokenizers/sequence.rs | 17 +- tokenizers/src/pre_tokenizers/split.rs | 8 +- tokenizers/src/processors/bert.rs | 12 +- tokenizers/src/processors/roberta.rs | 16 +- tokenizers/src/processors/sequence.rs | 33 ++ tokenizers/src/processors/template.rs | 54 +- tokenizers/src/tokenizer/normalizer.rs | 6 + 22 files changed, 1013 insertions(+), 181 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 64f5670a..367ded9e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -55,7 +55,6 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 - - name: Install Rust uses: actions-rs/toolchain@v1 diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index e987716b..98bf2d69 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -14,7 +14,7 @@ serde = { version = "1.0", features = ["rc", "derive"] } serde_json = "1.0" libc = "0.2" env_logger = "0.11" -pyo3 = { version = "0.23", features = ["abi3", "abi3-py39"] } +pyo3 = { version = "0.23", features = ["abi3", "abi3-py39", "py-clone"] } numpy = "0.23" ndarray = "0.16" itertools = "0.12" diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index 234765f6..50a340ac 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml 
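Note for readers of this patch: a minimal sketch of the Python-side behaviour the new `__len__`/`__setitem__` bindings and element setters are meant to enable. It assumes a build that includes this patch and mirrors the tests added further down; no names beyond those tests are introduced.

    from tokenizers.normalizers import Sequence, BertNormalizer, Lowercase, Strip

    norm = Sequence([BertNormalizer(), Lowercase()])
    assert len(norm) == 2          # __len__ is added for normalizer sequences in this patch
    norm[1] = Strip()              # __setitem__ replaces an element in place
    assert norm[1].__class__ == Strip
    norm[0].lowercase = False      # element setters write through to the shared normalizer
    assert norm[0].lowercase is False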
@@ -2,8 +2,8 @@ name = 'tokenizers' requires-python = '>=3.9' authors = [ - {name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com'}, - {name = 'Anthony Moi', email = 'anthony@huggingface.co'} + { name = 'Nicolas Patry', email = 'patry.nicolas@protonmail.com' }, + { name = 'Anthony Moi', email = 'anthony@huggingface.co' }, ] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -21,12 +21,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] keywords = ["NLP", "tokenizer", "BPE", "transformer", "deep learning"] -dynamic = [ - 'description', - 'license', - 'readme', - 'version', -] +dynamic = ['description', 'license', 'readme', 'version'] dependencies = ["huggingface_hub>=0.16.4,<1.0"] [project.urls] @@ -58,16 +53,16 @@ target-version = ['py35'] line-length = 119 target-version = "py311" lint.ignore = [ - # a == None in tests vs is None. - "E711", - # a == False in tests vs is False. - "E712", - # try.. import except.. pattern without using the lib. - "F401", - # Raw type equality is required in asserts - "E721", - # Import order - "E402", - # Fixtures unused import - "F811", + # a == None in tests vs is None. + "E711", + # a == False in tests vs is False. + "E712", + # try.. import except.. pattern without using the lib. + "F401", + # Raw type equality is required in asserts + "E721", + # Import order + "E402", + # Fixtures unused import + "F811", ] diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 0d35cb1b..3cd59a3c 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -1,3 +1,4 @@ +use pyo3::exceptions::PyException; use pyo3::types::*; use pyo3::{exceptions, prelude::*}; use std::sync::{Arc, RwLock}; @@ -41,7 +42,7 @@ impl PyNormalizedStringMut<'_> { /// This class is not supposed to be instantiated directly. Instead, any implementation of a /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(transparent)] pub struct PyNormalizer { pub(crate) normalizer: PyNormalizerTypeWrapper, @@ -58,7 +59,11 @@ impl PyNormalizer { .into_pyobject(py)? .into_any() .into(), - PyNormalizerTypeWrapper::Single(ref inner) => match &*inner.as_ref().read().unwrap() { + PyNormalizerTypeWrapper::Single(ref inner) => match &*inner + .as_ref() + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? + { PyNormalizerWrapper::Custom(_) => { Py::new(py, base)?.into_pyobject(py)?.into_any().into() } @@ -218,7 +223,9 @@ macro_rules! getter { ($self: ident, $variant: ident, $name: ident) => {{ let super_ = $self.as_ref(); if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer { - let wrapper = norm.read().unwrap(); + let wrapper = norm.read().expect( + "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer", + ); if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (&*wrapper) { o.$name.clone() } else { @@ -234,7 +241,9 @@ macro_rules! 
setter { ($self: ident, $variant: ident, $name: ident, $value: expr) => {{ let super_ = $self.as_ref(); if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer { - let mut wrapper = norm.write().unwrap(); + let mut wrapper = norm.write().expect( + "RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer", + ); if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(ref mut o)) = *wrapper { o.$name = $value; } @@ -410,25 +419,55 @@ impl PySequence { PyTuple::new(py, [PyList::empty(py)]) } - fn __len__(&self) -> usize { - 0 + fn __len__(self_: PyRef<'_, Self>) -> usize { + match &self_.as_ref().normalizer { + PyNormalizerTypeWrapper::Sequence(inner) => inner.len(), + PyNormalizerTypeWrapper::Single(_) => 1, + } } fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { match &self_.as_ref().normalizer { PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) { - Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(item))) + Some(item) => PyNormalizer::new(PyNormalizerTypeWrapper::Single(item.clone())) .get_as_subtype(py), _ => Err(PyErr::new::( "Index not found", )), }, PyNormalizerTypeWrapper::Single(inner) => { - PyNormalizer::new(PyNormalizerTypeWrapper::Single(Arc::clone(inner))) - .get_as_subtype(py) + PyNormalizer::new(PyNormalizerTypeWrapper::Single(inner.clone())).get_as_subtype(py) } } } + + fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> { + let norm: PyNormalizer = value.extract()?; + let PyNormalizerTypeWrapper::Single(norm) = norm.normalizer else { + return Err(PyException::new_err("normalizer should not be a sequence")); + }; + match &self_.as_ref().normalizer { + PyNormalizerTypeWrapper::Sequence(inner) => match inner.get(index) { + Some(item) => { + *item + .write() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? = norm + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? 
+ .clone(); + } + _ => { + return Err(PyErr::new::( + "Index not found", + )) + } + }, + PyNormalizerTypeWrapper::Single(_) => { + return Err(PyException::new_err("normalizer is not a sequence")) + } + }; + Ok(()) + } } /// Lowercase Normalizer @@ -570,9 +609,31 @@ impl PyReplace { ToPyResult(Replace::new(pattern, content)).into_py()?.into(), )) } + + #[getter] + fn get_pattern(_self: PyRef) -> PyResult<()> { + Err(PyException::new_err("Cannot get pattern")) + } + + #[setter] + fn set_pattern(_self: PyRef, _pattern: PyPattern) -> PyResult<()> { + Err(PyException::new_err( + "Cannot set pattern, please instantiate a new replace pattern instead", + )) + } + + #[getter] + fn get_content(self_: PyRef) -> String { + getter!(self_, Replace, content) + } + + #[setter] + fn set_content(self_: PyRef, content: String) { + setter!(self_, Replace, content, content) + } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct CustomNormalizer { inner: PyObject, } @@ -615,7 +676,7 @@ impl<'de> Deserialize<'de> for CustomNormalizer { } } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] #[serde(untagged)] pub(crate) enum PyNormalizerWrapper { Custom(CustomNormalizer), @@ -634,13 +695,27 @@ impl Serialize for PyNormalizerWrapper { } } -#[derive(Debug, Clone, Deserialize)] -#[serde(untagged)] +#[derive(Debug, Clone)] pub(crate) enum PyNormalizerTypeWrapper { Sequence(Vec>>), Single(Arc>), } +/// XXX: we need to manually implement deserialize here because of the structure of the +/// PyNormalizerTypeWrapper enum. Given the underlying PyNormalizerWrapper can contain a Sequence, +/// default deserialization will give us a PyNormalizerTypeWrapper::Single(Sequence) when we'd like +/// it to be PyNormalizerTypeWrapper::Sequence(// ...). +impl<'de> Deserialize<'de> for PyNormalizerTypeWrapper { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let wrapper = NormalizerWrapper::deserialize(deserializer)?; + let py_wrapper: PyNormalizerWrapper = wrapper.into(); + Ok(py_wrapper.into()) + } +} + impl Serialize for PyNormalizerTypeWrapper { fn serialize(&self, serializer: S) -> Result where @@ -672,7 +747,17 @@ where I: Into, { fn from(norm: I) -> Self { - PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm.into()))) + let norm = norm.into(); + match norm { + PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(seq)) => { + PyNormalizerTypeWrapper::Sequence( + seq.into_iter() + .map(|e| Arc::new(RwLock::new(PyNormalizerWrapper::Wrapped(e.clone())))) + .collect(), + ) + } + _ => PyNormalizerTypeWrapper::Single(Arc::new(RwLock::new(norm))), + } } } @@ -690,10 +775,15 @@ where impl Normalizer for PyNormalizerTypeWrapper { fn normalize(&self, normalized: &mut NormalizedString) -> tk::Result<()> { match self { - PyNormalizerTypeWrapper::Single(inner) => inner.read().unwrap().normalize(normalized), - PyNormalizerTypeWrapper::Sequence(inner) => inner - .iter() - .try_for_each(|n| n.read().unwrap().normalize(normalized)), + PyNormalizerTypeWrapper::Single(inner) => inner + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? + .normalize(normalized), + PyNormalizerTypeWrapper::Sequence(inner) => inner.iter().try_for_each(|n| { + n.read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyNormalizer"))? 
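To illustrate what the manual `Deserialize` above is protecting against (a serialized `Sequence` coming back as one opaque `Single` wrapper), a short sketch of the observable Python behaviour, assuming pickling keeps going through the serde-based `__getstate__`/`__setstate__` path:

    import pickle
    from tokenizers.normalizers import Sequence, NFKC

    norm = Sequence([NFKC()])
    restored = pickle.loads(pickle.dumps(norm))   # serde round-trip through the bindings
    assert restored[0].__class__ == NFKC          # still an indexable Sequence, not one opaque wrapper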
+ .normalize(normalized) + }), } } } @@ -793,18 +883,14 @@ mod test { let normalizer: PyNormalizer = serde_json::from_str(&sequence_string).unwrap(); match normalizer.normalizer { - PyNormalizerTypeWrapper::Single(inner) => match &*inner.as_ref().read().unwrap() { - PyNormalizerWrapper::Wrapped(NormalizerWrapper::Sequence(sequence)) => { - let normalizers = sequence.get_normalizers(); - assert_eq!(normalizers.len(), 1); - match normalizers[0] { - NormalizerWrapper::NFKC(_) => {} - _ => panic!("Expected NFKC"), - } - } - _ => panic!("Expected sequence"), - }, - _ => panic!("Expected single"), + PyNormalizerTypeWrapper::Sequence(inner) => { + assert_eq!(inner.len(), 1); + match *inner[0].as_ref().read().unwrap() { + PyNormalizerWrapper::Wrapped(NormalizerWrapper::NFKC(_)) => {} + _ => panic!("Expected NFKC"), + }; + } + _ => panic!("Expected sequence"), }; } } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 367dd143..8140ade1 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -1,6 +1,7 @@ use std::sync::{Arc, RwLock}; use pyo3::exceptions; +use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::types::*; use serde::ser::SerializeStruct; @@ -48,13 +49,17 @@ impl PyPreTokenizer { pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult { let base = self.clone(); - Ok(match &self.pretok { + Ok(match self.pretok { PyPreTokenizerTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))? .into_pyobject(py)? .into_any() .into(), PyPreTokenizerTypeWrapper::Single(ref inner) => { - match &*inner.as_ref().read().unwrap() { + match &*inner + .as_ref() + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? + { PyPreTokenizerWrapper::Custom(_) => { Py::new(py, base)?.into_pyobject(py)?.into_any().into() } @@ -222,7 +227,7 @@ macro_rules! getter { let super_ = $self.as_ref(); if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok { if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref pretok)) = - *single.read().unwrap() { + *single.read().expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer") { pretok.$($name)+ } else { unreachable!() @@ -238,7 +243,7 @@ macro_rules! setter { let super_ = $self.as_ref(); if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok { if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) = - *single.write().unwrap() + *single.write().expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer") { pretok.$name = $value; } @@ -248,7 +253,7 @@ macro_rules! 
setter { let super_ = $self.as_ref(); if let PyPreTokenizerTypeWrapper::Single(ref single) = super_.pretok { if let PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::$variant(ref mut pretok)) = - *single.write().unwrap() + *single.write().expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer") { pretok.$name($value); } @@ -292,6 +297,16 @@ impl PyByteLevel { setter!(self_, ByteLevel, use_regex, use_regex); } + #[getter] + fn get_trim_offsets(self_: PyRef) -> bool { + getter!(self_, ByteLevel, trim_offsets) + } + + #[setter] + fn set_trim_offsets(self_: PyRef, trim_offsets: bool) { + setter!(self_, ByteLevel, trim_offsets, trim_offsets) + } + #[new] #[pyo3(signature = (add_prefix_space = true, use_regex = true, **_kwargs), text_signature = "(self, add_prefix_space=True, use_regex=True)")] fn new( @@ -392,6 +407,52 @@ impl PySplit { fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult> { PyTuple::new(py, [" ", "removed"]) } + + #[getter] + fn get_pattern(_self: PyRef) -> PyResult<()> { + Err(PyException::new_err("Cannot get pattern")) + } + + #[setter] + fn set_pattern(_self: PyRef, _pattern: PyPattern) -> PyResult<()> { + Err(PyException::new_err( + "Cannot set pattern, please instantiate a new split pattern instead", + )) + } + + #[getter] + fn get_behavior(self_: PyRef) -> String { + getter!(self_, Split, behavior).to_string().to_lowercase() + } + + #[setter] + fn set_behavior(self_: PyRef, behavior: String) -> PyResult<()> { + let behavior = match behavior.as_ref() { + "removed" => SplitDelimiterBehavior::Removed, + "isolated" => SplitDelimiterBehavior::Isolated, + "merged_with_previous" => SplitDelimiterBehavior::MergedWithPrevious, + "merged_with_next" => SplitDelimiterBehavior::MergedWithNext, + "contiguous" => SplitDelimiterBehavior::Contiguous, + _ => { + return Err(exceptions::PyValueError::new_err( + "Wrong value for SplitDelimiterBehavior, expected one of: \ + `removed, isolated, merged_with_previous, merged_with_next, contiguous`", + )) + } + }; + setter!(self_, Split, behavior, behavior); + Ok(()) + } + + #[getter] + fn get_invert(self_: PyRef) -> bool { + getter!(self_, Split, invert) + } + + #[setter] + fn set_invert(self_: PyRef, invert: bool) { + setter!(self_, Split, invert, invert) + } } /// This pre-tokenizer simply splits on the provided char. 
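A sketch of how the new `Split` behavior/invert accessors are expected to be used from Python; the accepted strings mirror the `SplitDelimiterBehavior` values listed above, and the snippet assumes a build with this patch:

    from tokenizers.pre_tokenizers import Split

    pre = Split(pattern=" ", behavior="removed")
    pre.behavior = "isolated"   # accepted: removed, isolated, merged_with_previous, merged_with_next, contiguous
    pre.invert = True
    assert pre.behavior == "isolated"
    assert pre.invert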
Works like `.split(delimiter)` @@ -458,6 +519,32 @@ impl PyPunctuation { fn new(behavior: PySplitDelimiterBehavior) -> (Self, PyPreTokenizer) { (PyPunctuation {}, Punctuation::new(behavior.into()).into()) } + + #[getter] + fn get_behavior(self_: PyRef) -> String { + getter!(self_, Punctuation, behavior) + .to_string() + .to_lowercase() + } + + #[setter] + fn set_behavior(self_: PyRef, behavior: String) -> PyResult<()> { + let behavior = match behavior.as_ref() { + "removed" => SplitDelimiterBehavior::Removed, + "isolated" => SplitDelimiterBehavior::Isolated, + "merged_with_previous" => SplitDelimiterBehavior::MergedWithPrevious, + "merged_with_next" => SplitDelimiterBehavior::MergedWithNext, + "contiguous" => SplitDelimiterBehavior::Contiguous, + _ => { + return Err(exceptions::PyValueError::new_err( + "Wrong value for SplitDelimiterBehavior, expected one of: \ + `removed, isolated, merged_with_previous, merged_with_next, contiguous`", + )) + } + }; + setter!(self_, Punctuation, behavior, behavior); + Ok(()) + } } /// This pre-tokenizer composes other pre_tokenizers and applies them in sequence @@ -491,20 +578,47 @@ impl PySequence { fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { match &self_.as_ref().pretok { PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) { - Some(item) => { - PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(item))) - .get_as_subtype(py) - } + Some(item) => PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(item.clone())) + .get_as_subtype(py), _ => Err(PyErr::new::( "Index not found", )), }, - PyPreTokenizerTypeWrapper::Single(inner) => { - PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Single(Arc::clone(inner))) - .get_as_subtype(py) - } + _ => Err(PyErr::new::( + "This processor is not a Sequence, it does not support __getitem__", + )), } } + + fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> { + let pretok: PyPreTokenizer = value.extract()?; + let PyPreTokenizerTypeWrapper::Single(pretok) = pretok.pretok else { + return Err(PyException::new_err( + "pre tokenizer should not be a sequence", + )); + }; + match &self_.as_ref().pretok { + PyPreTokenizerTypeWrapper::Sequence(inner) => match inner.get(index) { + Some(item) => { + *item + .write() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? = (*pretok + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))?) 
+ .clone(); + } + _ => { + return Err(PyErr::new::( + "Index not found", + )) + } + }, + PyPreTokenizerTypeWrapper::Single(_) => { + return Err(PyException::new_err("pre tokenizer is not a sequence")) + } + }; + Ok(()) + } } pub(crate) fn from_string(string: String) -> Result { @@ -565,13 +679,7 @@ impl PyMetaspace { #[getter] fn get_prepend_scheme(self_: PyRef) -> String { // Assuming Metaspace has a method to get the prepend_scheme as a string - let scheme: PrependScheme = getter!(self_, Metaspace, get_prepend_scheme()); - match scheme { - PrependScheme::First => "first", - PrependScheme::Never => "never", - PrependScheme::Always => "always", - } - .to_string() + getter!(self_, Metaspace, get_prepend_scheme()).to_string() } #[setter] @@ -642,6 +750,7 @@ impl PyUnicodeScripts { } } +#[derive(Clone)] pub(crate) struct CustomPreTokenizer { inner: PyObject, } @@ -685,7 +794,7 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer { } } -#[derive(Deserialize)] +#[derive(Clone, Deserialize)] #[serde(untagged)] pub(crate) enum PyPreTokenizerWrapper { Custom(CustomPreTokenizer), @@ -704,13 +813,23 @@ impl Serialize for PyPreTokenizerWrapper { } } -#[derive(Clone, Deserialize)] -#[serde(untagged)] +#[derive(Clone)] pub(crate) enum PyPreTokenizerTypeWrapper { Sequence(Vec>>), Single(Arc>), } +impl<'de> Deserialize<'de> for PyPreTokenizerTypeWrapper { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let wrapper = PreTokenizerWrapper::deserialize(deserializer)?; + let py_wrapper: PyPreTokenizerWrapper = wrapper.into(); + Ok(py_wrapper.into()) + } +} + impl Serialize for PyPreTokenizerTypeWrapper { fn serialize(&self, serializer: S) -> Result where @@ -742,7 +861,17 @@ where I: Into, { fn from(pretok: I) -> Self { - PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok.into()))) + let pretok = pretok.into(); + match pretok { + PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Sequence(seq)) => { + PyPreTokenizerTypeWrapper::Sequence( + seq.into_iter() + .map(|e| Arc::new(RwLock::new(PyPreTokenizerWrapper::Wrapped(e.clone())))) + .collect(), + ) + } + _ => PyPreTokenizerTypeWrapper::Single(Arc::new(RwLock::new(pretok))), + } } } @@ -760,10 +889,15 @@ where impl PreTokenizer for PyPreTokenizerTypeWrapper { fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> { match self { - PyPreTokenizerTypeWrapper::Single(inner) => inner.read().unwrap().pre_tokenize(pretok), - PyPreTokenizerTypeWrapper::Sequence(inner) => inner - .iter() - .try_for_each(|n| n.read().unwrap().pre_tokenize(pretok)), + PyPreTokenizerTypeWrapper::Single(inner) => inner + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? + .pre_tokenize(pretok), + PyPreTokenizerTypeWrapper::Sequence(inner) => inner.iter().try_for_each(|n| { + n.read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? 
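On the pre-tokenizer side, a minimal sketch of the sequence indexing and assignment this enables; it mirrors the new `test_set_item`/getter tests below and assumes a build with this patch:

    from tokenizers.pre_tokenizers import Sequence, ByteLevel, Split, Metaspace

    pre = Sequence([ByteLevel(), Split(pattern="/test/", behavior="removed")])
    assert pre[1].__class__ == Split
    pre[1] = Metaspace()               # __setitem__ swaps the element in place
    assert pre[1].__class__ == Metaspace
    pre[0].use_regex = False           # element setters write through to the sequence
    assert not pre[0].use_regex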
+ .pre_tokenize(pretok) + }), } } } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index d558c40b..07784afa 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -1,17 +1,20 @@ use std::convert::TryInto; use std::sync::Arc; - -use pyo3::exceptions; -use pyo3::prelude::*; -use pyo3::types::*; +use std::sync::RwLock; use crate::encoding::PyEncoding; use crate::error::ToPyResult; +use pyo3::exceptions; +use pyo3::exceptions::PyException; +use pyo3::prelude::*; +use pyo3::types::*; +use serde::ser::SerializeStruct; +use serde::Deserializer; +use serde::Serializer; use serde::{Deserialize, Serialize}; use tk::processors::bert::BertProcessing; use tk::processors::byte_level::ByteLevel; use tk::processors::roberta::RobertaProcessing; -use tk::processors::sequence::Sequence; use tk::processors::template::{SpecialToken, Template}; use tk::processors::PostProcessorWrapper; use tk::{Encoding, PostProcessor}; @@ -30,42 +33,67 @@ use tokenizers as tk; #[derive(Clone, Deserialize, Serialize)] #[serde(transparent)] pub struct PyPostProcessor { - pub processor: Arc, + processor: PyPostProcessorTypeWrapper, +} + +impl From for PyPostProcessor +where + I: Into, +{ + fn from(processor: I) -> Self { + PyPostProcessor { + processor: processor.into().into(), + } + } } impl PyPostProcessor { - pub fn new(processor: Arc) -> Self { + pub(crate) fn new(processor: PyPostProcessorTypeWrapper) -> Self { PyPostProcessor { processor } } pub(crate) fn get_as_subtype(&self, py: Python<'_>) -> PyResult { let base = self.clone(); - Ok(match self.processor.as_ref() { - PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))? + Ok( + match self.processor { + PyPostProcessorTypeWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))? .into_pyobject(py)? .into_any() .into(), - PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))? - .into_pyobject(py)? - .into_any() - .into(), - PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))? - .into_pyobject(py)? - .into_any() - .into(), - PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))? - .into_pyobject(py)? - .into_any() - .into(), - PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))? - .into_pyobject(py)? - .into_any() - .into(), - }) + PyPostProcessorTypeWrapper::Single(ref inner) => { + + match &*inner.read().map_err(|_| { + PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor") + })? { + PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))? + .into_pyobject(py)? + .into_any() + .into(), + PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))? + .into_pyobject(py)? + .into_any() + .into(), + PostProcessorWrapper::Roberta(_) => Py::new(py, (PyRobertaProcessing {}, base))? + .into_pyobject(py)? + .into_any() + .into(), + PostProcessorWrapper::Template(_) => Py::new(py, (PyTemplateProcessing {}, base))? + .into_pyobject(py)? + .into_any() + .into(), + PostProcessorWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))? + .into_pyobject(py)? 
+ .into_any() + .into(), + } + } + } + ) } } impl PostProcessor for PyPostProcessor { + // TODO: update signature to `tk::Result` fn added_tokens(&self, is_pair: bool) -> usize { self.processor.added_tokens(is_pair) } @@ -83,7 +111,7 @@ impl PostProcessor for PyPostProcessor { #[pymethods] impl PyPostProcessor { fn __getstate__(&self, py: Python) -> PyResult { - let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| { + let data = serde_json::to_string(&self.processor).map_err(|e| { exceptions::PyException::new_err(format!( "Error while attempting to pickle PostProcessor: {}", e @@ -116,8 +144,8 @@ impl PyPostProcessor { /// Returns: /// :obj:`int`: The number of tokens to add #[pyo3(text_signature = "(self, is_pair)")] - fn num_special_tokens_to_add(&self, is_pair: bool) -> usize { - self.processor.added_tokens(is_pair) + fn num_special_tokens_to_add(&self, is_pair: bool) -> PyResult { + Ok(self.processor.added_tokens(is_pair)) } /// Post-process the given encodings, generating the final one @@ -162,6 +190,132 @@ impl PyPostProcessor { } } +macro_rules! getter { + ($self: ident, $variant: ident, $($name: tt)+) => {{ + let super_ = $self.as_ref(); + if let PyPostProcessorTypeWrapper::Single(ref single) = super_.processor { + if let PostProcessorWrapper::$variant(ref post) = *single.read().expect( + "RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor" + ) { + post.$($name)+ + } else { + unreachable!() + } + } else { + unreachable!() + } + }}; +} + +macro_rules! setter { + ($self: ident, $variant: ident, $name: ident, $value: expr) => {{ + let super_ = $self.as_ref(); + if let PyPostProcessorTypeWrapper::Single(ref single) = super_.processor { + if let PostProcessorWrapper::$variant(ref mut post) = *single.write().expect( + "RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor", + ) { + post.$name = $value; + } + } + }}; + ($self: ident, $variant: ident, @$name: ident, $value: expr) => {{ + let super_ = $self.as_ref(); + if let PyPostProcessorTypeWrapper::Single(ref single) = super_.processor { + if let PostProcessorWrapper::$variant(ref mut post) = *single.write().expect( + "RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor", + ) { + post.$name($value); + } + } + };}; +} + +#[derive(Clone)] +pub(crate) enum PyPostProcessorTypeWrapper { + Sequence(Vec>>), + Single(Arc>), +} + +impl PostProcessor for PyPostProcessorTypeWrapper { + fn added_tokens(&self, is_pair: bool) -> usize { + match self { + PyPostProcessorTypeWrapper::Single(inner) => inner + .read() + .expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor") + .added_tokens(is_pair), + PyPostProcessorTypeWrapper::Sequence(inner) => inner.iter().map(|p| { + p.read() + .expect("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor") + .added_tokens(is_pair) + }).sum::(), + } + } + + fn process_encodings( + &self, + mut encodings: Vec, + add_special_tokens: bool, + ) -> tk::Result> { + match self { + PyPostProcessorTypeWrapper::Single(inner) => inner + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? 
+ .process_encodings(encodings, add_special_tokens), + PyPostProcessorTypeWrapper::Sequence(inner) => { + for processor in inner.iter() { + encodings = processor + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPreTokenizer"))? + .process_encodings(encodings, add_special_tokens)?; + } + Ok(encodings) + }, + } + } +} + +impl<'de> Deserialize<'de> for PyPostProcessorTypeWrapper { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let wrapper = PostProcessorWrapper::deserialize(deserializer)?; + Ok(wrapper.into()) + } +} + +impl Serialize for PyPostProcessorTypeWrapper { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match self { + PyPostProcessorTypeWrapper::Sequence(seq) => { + let mut ser = serializer.serialize_struct("Sequence", 2)?; + ser.serialize_field("type", "Sequence")?; + ser.serialize_field("processors", seq)?; + ser.end() + } + PyPostProcessorTypeWrapper::Single(inner) => inner.serialize(serializer), + } + } +} + +impl From for PyPostProcessorTypeWrapper +where + I: Into, +{ + fn from(processor: I) -> Self { + let processor = processor.into(); + match processor { + PostProcessorWrapper::Sequence(seq) => PyPostProcessorTypeWrapper::Sequence( + seq.into_iter().map(|p| Arc::new(RwLock::new(p))).collect(), + ), + _ => PyPostProcessorTypeWrapper::Single(Arc::new(RwLock::new(processor.clone()))), + } + } +} + /// This post-processor takes care of adding the special tokens needed by /// a Bert model: /// @@ -181,15 +335,46 @@ impl PyBertProcessing { #[new] #[pyo3(text_signature = "(self, sep, cls)")] fn new(sep: (String, u32), cls: (String, u32)) -> (Self, PyPostProcessor) { - ( - PyBertProcessing {}, - PyPostProcessor::new(Arc::new(BertProcessing::new(sep, cls).into())), - ) + (PyBertProcessing {}, BertProcessing::new(sep, cls).into()) } fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult> { PyTuple::new(py, [("", 0), ("", 0)]) } + + #[getter] + fn get_sep(self_: PyRef) -> Result, PyErr> { + let py = self_.py(); + let (tok, id) = getter!(self_, Bert, get_sep_copy()); + PyTuple::new( + py, + Vec::::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]), + ) + } + + #[setter] + fn set_sep(self_: PyRef, sep: Bound<'_, PyTuple>) -> PyResult<()> { + let sep = sep.extract()?; + setter!(self_, Bert, sep, sep); + Ok(()) + } + + #[getter] + fn get_cls(self_: PyRef) -> Result, PyErr> { + let py = self_.py(); + let (tok, id) = getter!(self_, Bert, get_cls_copy()); + PyTuple::new( + py, + Vec::::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]), + ) + } + + #[setter] + fn set_cls(self_: PyRef, cls: Bound<'_, PyTuple>) -> PyResult<()> { + let cls = cls.extract()?; + setter!(self_, Bert, cls, cls); + Ok(()) + } } /// This post-processor takes care of adding the special tokens needed by @@ -231,15 +416,66 @@ impl PyRobertaProcessing { let proc = RobertaProcessing::new(sep, cls) .trim_offsets(trim_offsets) .add_prefix_space(add_prefix_space); - ( - PyRobertaProcessing {}, - PyPostProcessor::new(Arc::new(proc.into())), - ) + (PyRobertaProcessing {}, proc.into()) } fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult> { PyTuple::new(py, [("", 0), ("", 0)]) } + + #[getter] + fn get_sep(self_: PyRef) -> Result, PyErr> { + let py = self_.py(); + let (tok, id) = getter!(self_, Roberta, get_sep_copy()); + PyTuple::new( + py, + Vec::::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]), + ) + } + + #[setter] + fn 
set_sep(self_: PyRef, sep: Bound<'_, PyTuple>) -> PyResult<()> { + let sep = sep.extract()?; + setter!(self_, Roberta, sep, sep); + Ok(()) + } + + #[getter] + fn get_cls(self_: PyRef) -> Result, PyErr> { + let py = self_.py(); + let (tok, id) = getter!(self_, Roberta, get_cls_copy()); + PyTuple::new( + py, + Vec::::from([tok.into_pyobject(py)?.into(), id.into_pyobject(py)?.into()]), + ) + } + + #[setter] + fn set_cls(self_: PyRef, cls: Bound<'_, PyTuple>) -> PyResult<()> { + let cls = cls.extract()?; + setter!(self_, Roberta, cls, cls); + Ok(()) + } + + #[getter] + fn get_trim_offsets(self_: PyRef) -> bool { + getter!(self_, Roberta, trim_offsets) + } + + #[setter] + fn set_trim_offsets(self_: PyRef, trim_offsets: bool) { + setter!(self_, Roberta, trim_offsets, trim_offsets) + } + + #[getter] + fn get_add_prefix_space(self_: PyRef) -> bool { + getter!(self_, Roberta, add_prefix_space) + } + + #[setter] + fn set_add_prefix_space(self_: PyRef, add_prefix_space: bool) { + setter!(self_, Roberta, add_prefix_space, add_prefix_space) + } } /// This post-processor takes care of trimming the offsets. @@ -255,21 +491,58 @@ pub struct PyByteLevel {} #[pymethods] impl PyByteLevel { #[new] - #[pyo3(signature = (trim_offsets = None, **_kwargs), text_signature = "(self, trim_offsets=True)")] + #[pyo3(signature = (add_prefix_space = None, trim_offsets = None, use_regex = None, **_kwargs), text_signature = "(self, trim_offsets=True)")] fn new( + add_prefix_space: Option, trim_offsets: Option, + use_regex: Option, _kwargs: Option<&Bound<'_, PyDict>>, ) -> (Self, PyPostProcessor) { let mut byte_level = ByteLevel::default(); + if let Some(aps) = add_prefix_space { + byte_level = byte_level.add_prefix_space(aps); + } + if let Some(to) = trim_offsets { byte_level = byte_level.trim_offsets(to); } - ( - PyByteLevel {}, - PyPostProcessor::new(Arc::new(byte_level.into())), - ) + if let Some(ur) = use_regex { + byte_level = byte_level.use_regex(ur); + } + + (PyByteLevel {}, byte_level.into()) + } + + #[getter] + fn get_add_prefix_space(self_: PyRef) -> bool { + getter!(self_, ByteLevel, add_prefix_space) + } + + #[setter] + fn set_add_prefix_space(self_: PyRef, add_prefix_space: bool) { + setter!(self_, ByteLevel, add_prefix_space, add_prefix_space) + } + + #[getter] + fn get_trim_offsets(self_: PyRef) -> bool { + getter!(self_, ByteLevel, trim_offsets) + } + + #[setter] + fn set_trim_offsets(self_: PyRef, trim_offsets: bool) { + setter!(self_, ByteLevel, trim_offsets, trim_offsets) + } + + #[getter] + fn get_use_regex(self_: PyRef) -> bool { + getter!(self_, ByteLevel, use_regex) + } + + #[setter] + fn set_use_regex(self_: PyRef, use_regex: bool) { + setter!(self_, ByteLevel, use_regex, use_regex) } } @@ -430,10 +703,26 @@ impl PyTemplateProcessing { .build() .map_err(|e| exceptions::PyValueError::new_err(e.to_string()))?; - Ok(( - PyTemplateProcessing {}, - PyPostProcessor::new(Arc::new(processor.into())), - )) + Ok((PyTemplateProcessing {}, processor.into())) + } + + #[getter] + fn get_single(self_: PyRef) -> String { + getter!(self_, Template, get_single()) + } + + #[setter] + fn set_single(self_: PyRef, single: PyTemplate) -> PyResult<()> { + let template: Template = Template::from(single); + let super_ = self_.as_ref(); + if let PyPostProcessorTypeWrapper::Single(ref inner) = super_.processor { + if let PostProcessorWrapper::Template(ref mut post) = *inner + .write() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))? 
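A small sketch of the new `(token, id)` accessors on the BERT/RoBERTa post-processors; the token strings and ids are illustrative only, and the snippet assumes a build with this patch:

    from tokenizers.processors import BertProcessing, RobertaProcessing

    bert = BertProcessing(sep=("[SEP]", 102), cls=("[CLS]", 101))
    assert bert.sep == ("[SEP]", 102)   # returned as a (token, id) tuple
    bert.cls = ("[CLS]", 1)             # setters take the same tuple shape

    roberta = RobertaProcessing(sep=("</s>", 2), cls=("<s>", 0))
    roberta.trim_offsets = False
    assert not roberta.trim_offsets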
{ + post.set_single(template); + } + } + Ok(()) } } @@ -444,27 +733,79 @@ impl PyTemplateProcessing { /// The processors that need to be chained #[pyclass(extends=PyPostProcessor, module = "tokenizers.processors", name = "Sequence")] pub struct PySequence {} + #[pymethods] impl PySequence { #[new] #[pyo3(signature = (processors_py), text_signature = "(self, processors)")] - fn new(processors_py: &Bound<'_, PyList>) -> (Self, PyPostProcessor) { - let mut processors: Vec = Vec::with_capacity(processors_py.len()); + fn new(processors_py: &Bound<'_, PyList>) -> PyResult<(Self, PyPostProcessor)> { + let mut processors = Vec::with_capacity(processors_py.len()); for n in processors_py.iter() { - let processor: PyRef = n.extract().unwrap(); - let processor = processor.processor.as_ref(); - processors.push(processor.clone()); + let processor: PyRef = n.extract()?; + match &processor.processor { + PyPostProcessorTypeWrapper::Sequence(inner) => { + processors.extend(inner.iter().cloned()) + } + PyPostProcessorTypeWrapper::Single(inner) => processors.push(inner.clone()), + } } - let sequence_processor = Sequence::new(processors); - ( + Ok(( PySequence {}, - PyPostProcessor::new(Arc::new(PostProcessorWrapper::Sequence(sequence_processor))), - ) + PyPostProcessor::new(PyPostProcessorTypeWrapper::Sequence(processors)), + )) } fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult> { PyTuple::new(py, [PyList::empty(py)]) } + + fn __getitem__(self_: PyRef<'_, Self>, py: Python<'_>, index: usize) -> PyResult> { + match &self_.as_ref().processor { + PyPostProcessorTypeWrapper::Sequence(ref inner) => match inner.get(index) { + Some(item) => { + PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(item.clone())) + .get_as_subtype(py) + } + _ => Err(PyErr::new::( + "Index not found", + )), + }, + _ => Err(PyErr::new::( + "This processor is not a Sequence, it does not support __getitem__", + )), + } + } + + fn __setitem__(self_: PyRef<'_, Self>, index: usize, value: Bound<'_, PyAny>) -> PyResult<()> { + let processor: PyPostProcessor = value.extract()?; + let PyPostProcessorTypeWrapper::Single(processor) = processor.processor else { + return Err(PyException::new_err("processor should not be a sequence")); + }; + + match &self_.as_ref().processor { + PyPostProcessorTypeWrapper::Sequence(inner) => match inner.get(index) { + Some(item) => { + *item + .write() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))? = processor + .read() + .map_err(|_| PyException::new_err("RwLock synchronisation primitive is poisoned, cannot get subtype of PyPostProcessor"))? 
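A minimal sketch of the post-processor sequence indexing and assignment added here; it is essentially the new `test_items` below and assumes a build with this patch:

    from tokenizers.processors import Sequence, RobertaProcessing, ByteLevel

    post = Sequence([RobertaProcessing(("</s>", 1), ("<s>", 0)), ByteLevel()])
    assert post[0].__class__ == RobertaProcessing
    post[0] = ByteLevel(add_prefix_space=False, trim_offsets=False, use_regex=False)
    post[0].trim_offsets = True          # element setters write through to the sequence
    assert post[0].__class__ == ByteLevel
    assert post[0].trim_offsets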
+ .clone(); + } + _ => { + return Err(PyErr::new::( + "Index not found", + )) + } + }, + _ => { + return Err(PyException::new_err( + "This processor is not a Sequence, it does not support __setitem__", + )) + } + }; + Ok(()) + } } /// Processors Module @@ -481,20 +822,20 @@ pub fn processors(m: &Bound<'_, PyModule>) -> PyResult<()> { #[cfg(test)] mod test { - use std::sync::Arc; + use std::sync::{Arc, RwLock}; use pyo3::prelude::*; use tk::processors::bert::BertProcessing; use tk::processors::PostProcessorWrapper; - use crate::processors::PyPostProcessor; + use crate::processors::{PyPostProcessor, PyPostProcessorTypeWrapper}; #[test] fn get_subtype() { Python::with_gil(|py| { - let py_proc = PyPostProcessor::new(Arc::new( - BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into(), - )); + let py_proc = PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(Arc::new( + RwLock::new(BertProcessing::new(("SEP".into(), 0), ("CLS".into(), 1)).into()), + ))); let py_bert = py_proc.get_as_subtype(py).unwrap(); assert_eq!( "BertProcessing", @@ -510,21 +851,29 @@ mod test { let rs_processing_ser = serde_json::to_string(&rs_processing).unwrap(); let rs_wrapper_ser = serde_json::to_string(&rs_wrapper).unwrap(); - let py_processing = PyPostProcessor::new(Arc::new(rs_wrapper)); + let py_processing = PyPostProcessor::new(PyPostProcessorTypeWrapper::Single(Arc::new( + RwLock::new(rs_wrapper), + ))); let py_ser = serde_json::to_string(&py_processing).unwrap(); assert_eq!(py_ser, rs_processing_ser); assert_eq!(py_ser, rs_wrapper_ser); let py_processing: PyPostProcessor = serde_json::from_str(&rs_processing_ser).unwrap(); - match py_processing.processor.as_ref() { - PostProcessorWrapper::Bert(_) => (), - _ => panic!("Expected Bert postprocessor."), + match py_processing.processor { + PyPostProcessorTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() { + PostProcessorWrapper::Bert(_) => (), + _ => panic!("Expected Bert postprocessor."), + }, + _ => panic!("Expected a single processor, got a sequence"), } let py_processing: PyPostProcessor = serde_json::from_str(&rs_wrapper_ser).unwrap(); - match py_processing.processor.as_ref() { - PostProcessorWrapper::Bert(_) => (), - _ => panic!("Expected Bert postprocessor."), - } + match py_processing.processor { + PyPostProcessorTypeWrapper::Single(inner) => match *inner.as_ref().read().unwrap() { + PostProcessorWrapper::Bert(_) => (), + _ => panic!("Expected Bert postprocessor."), + }, + _ => panic!("Expected a single processor, got a sequence"), + }; } } diff --git a/bindings/python/src/utils/normalization.rs b/bindings/python/src/utils/normalization.rs index 9de0ece3..21a9ae96 100644 --- a/bindings/python/src/utils/normalization.rs +++ b/bindings/python/src/utils/normalization.rs @@ -117,6 +117,12 @@ impl From for SplitDelimiterBehavior { } } +impl From for PySplitDelimiterBehavior { + fn from(v: SplitDelimiterBehavior) -> Self { + Self(v) + } +} + fn filter(normalized: &mut NormalizedString, func: &Bound<'_, PyAny>) -> PyResult<()> { let err = "`filter` expect a callable with the signature: `fn(char) -> bool`"; diff --git a/bindings/python/tests/bindings/test_normalizers.py b/bindings/python/tests/bindings/test_normalizers.py index 109f9ad2..99ab07d3 100644 --- a/bindings/python/tests/bindings/test_normalizers.py +++ b/bindings/python/tests/bindings/test_normalizers.py @@ -3,7 +3,16 @@ import pickle import pytest from tokenizers import NormalizedString -from tokenizers.normalizers import BertNormalizer, Lowercase, Normalizer, Sequence, 
Strip, Prepend +from tokenizers.normalizers import ( + BertNormalizer, + Lowercase, + Normalizer, + Precompiled, + Sequence, + Strip, + Prepend, + Replace, +) class TestBertNormalizer: @@ -67,14 +76,57 @@ class TestSequence: output = normalizer.normalize_str(" HELLO ") assert output == "hello" - def test_items(self): - normalizers = Sequence([BertNormalizer(True, True), Prepend(), Strip()]) + def test_set_item(self): + normalizers = Sequence( + [ + BertNormalizer(True, True), + Prepend(prepend="test"), + ] + ) + assert normalizers[0].__class__ == BertNormalizer assert normalizers[1].__class__ == Prepend - normalizers[0].lowercase = False - assert not normalizers[0].lowercase - assert normalizers[2].__class__ == Strip + normalizers[1] = Strip() + assert normalizers[1].__class__ == Strip with pytest.raises(IndexError): - print(normalizers[3]) + print(normalizers[2]) + + def test_item_getters_and_setters(self): + normalizers = Sequence( + [ + BertNormalizer(clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True), + Strip(left=True, right=True), + Prepend(prepend="_"), + Replace(pattern="something", content="else"), + ] + ) + + assert normalizers[0].__class__ == BertNormalizer + normalizers[0].clean_text = False + normalizers[0].handle_chinese_chars = False + normalizers[0].strip_accents = False + normalizers[0].lowercase = False + assert not normalizers[0].clean_text + assert not normalizers[0].handle_chinese_chars + assert not normalizers[0].strip_accents + assert not normalizers[0].lowercase + + assert normalizers[1].__class__ == Strip + normalizers[1].left = False + normalizers[1].right = False + assert not normalizers[1].left + assert not normalizers[1].right + + assert normalizers[2].__class__ == Prepend + normalizers[2].prepend = " " + assert normalizers[2].prepend == " " + + assert normalizers[3].__class__ == Replace + with pytest.raises(Exception): + normalizers[3].pattern = "test" + with pytest.raises(Exception): + print(normalizers[3].pattern) + normalizers[3].content = "test" + assert normalizers[3].content == "test" class TestLowercase: diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py index 80086f42..3611930a 100644 --- a/bindings/python/tests/bindings/test_pre_tokenizers.py +++ b/bindings/python/tests/bindings/test_pre_tokenizers.py @@ -169,12 +169,69 @@ class TestSequence: ("?", (29, 30)), ] - def test_items(self): - pre_tokenizers = Sequence([Metaspace("a", "never", split=True), Punctuation()]) - assert pre_tokenizers[1].__class__ == Punctuation - assert pre_tokenizers[0].__class__ == Metaspace - pre_tokenizers[0].split = False - assert not pre_tokenizers[0].split + def test_set_item(self): + pre_tokenizers = Sequence( + [ + ByteLevel(), + Split(pattern="/test/", behavior="removed"), + ] + ) + assert pre_tokenizers[0].__class__ == ByteLevel + assert pre_tokenizers[1].__class__ == Split + pre_tokenizers[1] = Metaspace() + assert pre_tokenizers[1].__class__ == Metaspace + with pytest.raises(IndexError): + print(pre_tokenizers[2]) + + def test_item_getters_and_setters(self): + pre_tokenizers = Sequence( + [ + ByteLevel(add_prefix_space=True, trim_offsets=True, use_regex=True), + Split(pattern="/test/", behavior="removed", invert=False), + Metaspace("a", "never", split=False), + CharDelimiterSplit(delimiter=" "), + Punctuation(behavior="removed"), + Digits(individual_digits=True), + ] + ) + + assert pre_tokenizers[0].__class__ == ByteLevel + pre_tokenizers[0].add_prefix_space = False + 
pre_tokenizers[0].trim_offsets = False + pre_tokenizers[0].use_regex = False + assert not pre_tokenizers[0].add_prefix_space + assert not pre_tokenizers[0].trim_offsets + assert not pre_tokenizers[0].use_regex + + assert pre_tokenizers[1].__class__ == Split + with pytest.raises(Exception): + pre_tokenizers[1].pattern = "/pattern/" + pre_tokenizers[1].behavior = "isolated" + pre_tokenizers[1].invert = True + with pytest.raises(Exception): + pre_tokenizers[1].pattern + assert pre_tokenizers[1].behavior == "isolated" + assert pre_tokenizers[1].invert + + assert pre_tokenizers[2].__class__ == Metaspace + pre_tokenizers[2].replacement = " " + pre_tokenizers[2].prepend_scheme = "always" + pre_tokenizers[2].split = True + assert pre_tokenizers[2].replacement == " " + assert pre_tokenizers[2].prepend_scheme == "always" + assert pre_tokenizers[2].split + + assert pre_tokenizers[3].__class__ == CharDelimiterSplit + pre_tokenizers[3].delimiter = "_" + assert pre_tokenizers[3].delimiter == "_" + + assert pre_tokenizers[4].__class__ == Punctuation + pre_tokenizers[4].behavior = "isolated" + assert pre_tokenizers[4].behavior == "isolated" + + assert pre_tokenizers[5].__class__ == Digits + pre_tokenizers[5].individual_digits = False + assert not pre_tokenizers[5].individual_digits class TestDigits: diff --git a/bindings/python/tests/bindings/test_processors.py b/bindings/python/tests/bindings/test_processors.py index 842754a6..3038d869 100644 --- a/bindings/python/tests/bindings/test_processors.py +++ b/bindings/python/tests/bindings/test_processors.py @@ -227,3 +227,18 @@ class TestSequenceProcessing: # assert pair.ids == [1, 2, 3, 4, 5, 0, 6, 0] assert pair.type_ids == [0, 0, 0, 0, 0, 0, 1, 1] assert pair.offsets == [(0, 0), (0, 2), (3, 7), (8, 10), (12, 16), (0, 0), (0, 4), (0, 0)] + + def test_items(self): + processors = Sequence([RobertaProcessing(("", 1), ("", 0)), ByteLevel()]) + assert processors[0].__class__ == RobertaProcessing + assert processors[1].__class__ == ByteLevel + processors[0] = ByteLevel(add_prefix_space=False, trim_offsets=False, use_regex=False) + print(processors[0]) + processors[0].add_prefix_space = True + processors[0].trim_offsets = True + processors[0].use_regex = True + print(processors[0]) + assert processors[0].__class__ == ByteLevel + assert processors[0].add_prefix_space + assert processors[0].trim_offsets + assert processors[0].use_regex diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index 2a786d70..f400f13d 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -43,14 +43,14 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { where D: Deserializer<'de>, { - #[derive(Deserialize)] + #[derive(Debug, Deserialize)] pub struct Tagged { #[serde(rename = "type")] variant: EnumType, #[serde(flatten)] rest: serde_json::Value, } - #[derive(Serialize, Deserialize)] + #[derive(Debug, Serialize, Deserialize)] pub enum EnumType { Bert, Strip, @@ -168,7 +168,7 @@ impl<'de> Deserialize<'de> for NormalizerWrapper { NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe), NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe), NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe), - NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe), + NormalizerUntagged::Sequence(seq) => NormalizerWrapper::Sequence(seq), NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe), NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe), NormalizerUntagged::Precompiled(bpe) => 
NormalizerWrapper::Precompiled(bpe), diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs index cdd4a420..56575748 100644 --- a/tokenizers/src/normalizers/replace.rs +++ b/tokenizers/src/normalizers/replace.rs @@ -46,7 +46,7 @@ impl std::convert::TryFrom for Replace { #[serde(tag = "type", try_from = "ReplaceDeserializer")] pub struct Replace { pattern: ReplacePattern, - content: String, + pub content: String, #[serde(skip)] regex: SysRegex, } diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs index a7730a3f..1e33cc79 100644 --- a/tokenizers/src/normalizers/utils.rs +++ b/tokenizers/src/normalizers/utils.rs @@ -16,16 +16,29 @@ impl Sequence { pub fn new(normalizers: Vec) -> Self { Self { normalizers } } +} - pub fn get_normalizers(&self) -> &[NormalizerWrapper] { +impl AsRef<[NormalizerWrapper]> for Sequence { + fn as_ref(&self) -> &[NormalizerWrapper] { &self.normalizers } +} - pub fn get_normalizers_mut(&mut self) -> &mut [NormalizerWrapper] { +impl AsMut<[NormalizerWrapper]> for Sequence { + fn as_mut(&mut self) -> &mut [NormalizerWrapper] { &mut self.normalizers } } +impl IntoIterator for Sequence { + type Item = NormalizerWrapper; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.normalizers.into_iter() + } +} + impl Normalizer for Sequence { fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> { for normalizer in &self.normalizers { diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 52b415c9..d821f118 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -13,6 +13,12 @@ pub enum PrependScheme { Always, } +impl std::fmt::Display for PrependScheme { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.serialize(f) + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Eq)] /// Replaces all the whitespaces by the provided meta character and then /// splits on this character diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs index 0ba7d602..985e7762 100644 --- a/tokenizers/src/pre_tokenizers/punctuation.rs +++ b/tokenizers/src/pre_tokenizers/punctuation.rs @@ -12,7 +12,7 @@ fn is_punc(x: char) -> bool { #[macro_rules_attribute(impl_serde_type!)] pub struct Punctuation { #[serde(default = "default_split")] - behavior: SplitDelimiterBehavior, + pub behavior: SplitDelimiterBehavior, } fn default_split() -> SplitDelimiterBehavior { diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs index 9dcafc67..973b5e8c 100644 --- a/tokenizers/src/pre_tokenizers/sequence.rs +++ b/tokenizers/src/pre_tokenizers/sequence.rs @@ -13,16 +13,29 @@ impl Sequence { pub fn new(pretokenizers: Vec) -> Self { Self { pretokenizers } } +} - pub fn get_pre_tokenizers(&self) -> &[PreTokenizerWrapper] { +impl AsRef<[PreTokenizerWrapper]> for Sequence { + fn as_ref(&self) -> &[PreTokenizerWrapper] { &self.pretokenizers } +} - pub fn get_pre_tokenizers_mut(&mut self) -> &mut [PreTokenizerWrapper] { +impl AsMut<[PreTokenizerWrapper]> for Sequence { + fn as_mut(&mut self) -> &mut [PreTokenizerWrapper] { &mut self.pretokenizers } } +impl IntoIterator for Sequence { + type Item = PreTokenizerWrapper; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.pretokenizers.into_iter() + } +} + impl PreTokenizer for Sequence { fn 
pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { for pretokenizer in &self.pretokenizers { diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs index 0e2a9023..5f7362f7 100644 --- a/tokenizers/src/pre_tokenizers/split.rs +++ b/tokenizers/src/pre_tokenizers/split.rs @@ -27,11 +27,11 @@ impl From<&str> for SplitPattern { #[derive(Debug, Serialize)] #[serde(tag = "type")] pub struct Split { - pattern: SplitPattern, + pub pattern: SplitPattern, #[serde(skip)] - regex: SysRegex, - behavior: SplitDelimiterBehavior, - invert: bool, + pub regex: SysRegex, + pub behavior: SplitDelimiterBehavior, + pub invert: bool, } impl<'de> Deserialize<'de> for Split { diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index 627f9d18..17939112 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -6,8 +6,8 @@ use std::iter::FromIterator; #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] #[serde(tag = "type")] pub struct BertProcessing { - sep: (String, u32), - cls: (String, u32), + pub sep: (String, u32), + pub cls: (String, u32), } impl Default for BertProcessing { @@ -23,6 +23,14 @@ impl BertProcessing { pub fn new(sep: (String, u32), cls: (String, u32)) -> Self { Self { sep, cls } } + + pub fn get_sep_copy(&self) -> (String, u32) { + (self.sep.0.clone(), self.sep.1) + } + + pub fn get_cls_copy(&self) -> (String, u32) { + (self.cls.0.clone(), self.cls.1) + } } #[derive(thiserror::Error, Debug)] diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 3af9a8d6..5bbc4ea6 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -7,10 +7,10 @@ use std::iter::FromIterator; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] #[serde(tag = "type")] pub struct RobertaProcessing { - sep: (String, u32), - cls: (String, u32), - trim_offsets: bool, - add_prefix_space: bool, + pub sep: (String, u32), + pub cls: (String, u32), + pub trim_offsets: bool, + pub add_prefix_space: bool, } impl Default for RobertaProcessing { @@ -44,6 +44,14 @@ impl RobertaProcessing { self.add_prefix_space = v; self } + + pub fn get_sep_copy(&self) -> (String, u32) { + (self.sep.0.clone(), self.sep.1) + } + + pub fn get_cls_copy(&self) -> (String, u32) { + (self.cls.0.clone(), self.cls.1) + } } impl PostProcessor for RobertaProcessing { diff --git a/tokenizers/src/processors/sequence.rs b/tokenizers/src/processors/sequence.rs index 66c670ad..5cfb3eb5 100644 --- a/tokenizers/src/processors/sequence.rs +++ b/tokenizers/src/processors/sequence.rs @@ -13,6 +13,39 @@ impl Sequence { pub fn new(processors: Vec) -> Self { Self { processors } } + + pub fn get(&self, index: usize) -> Option<&PostProcessorWrapper> { + self.processors.get(index) + } + + pub fn get_mut(&mut self, index: usize) -> Option<&mut PostProcessorWrapper> { + self.processors.get_mut(index) + } + + pub fn set_mut(&mut self, index: usize, post_proc: PostProcessorWrapper) { + self.processors[index] = post_proc; + } +} + +impl AsRef<[PostProcessorWrapper]> for Sequence { + fn as_ref(&self) -> &[PostProcessorWrapper] { + &self.processors + } +} + +impl AsMut<[PostProcessorWrapper]> for Sequence { + fn as_mut(&mut self) -> &mut [PostProcessorWrapper] { + &mut self.processors + } +} + +impl IntoIterator for Sequence { + type Item = PostProcessorWrapper; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + 
self.processors.into_iter() + } } impl PostProcessor for Sequence { diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs index 58462331..74b4fe1c 100644 --- a/tokenizers/src/processors/template.rs +++ b/tokenizers/src/processors/template.rs @@ -338,7 +338,7 @@ impl From> for Tokens { #[builder(build_fn(validate = "Self::validate"))] pub struct TemplateProcessing { #[builder(try_setter, default = "\"$0\".try_into().unwrap()")] - single: Template, + pub single: Template, #[builder(try_setter, default = "\"$A:0 $B:1\".try_into().unwrap()")] pair: Template, #[builder(setter(skip), default = "self.default_added(true)")] @@ -351,6 +351,58 @@ pub struct TemplateProcessing { special_tokens: Tokens, } +impl TemplateProcessing { + // Getter for `single` + pub fn get_single(&self) -> String { + format!("{:?}", self.single) + } + + // Setter for `single` + pub fn set_single(&mut self, single: Template) { + self.single = single; + } + + // Getter for `pair` + pub fn get_pair(&self) -> &Template { + &self.pair + } + + // Setter for `pair` + pub fn set_pair(&mut self, pair: Template) { + self.pair = pair; + } + + // Getter for `added_single` + pub fn get_added_single(&self) -> usize { + self.added_single + } + + // Setter for `added_single` + pub fn set_added_single(&mut self, added_single: usize) { + self.added_single = added_single; + } + + // Getter for `added_pair` + pub fn get_added_pair(&self) -> usize { + self.added_pair + } + + // Setter for `added_pair` + pub fn set_added_pair(&mut self, added_pair: usize) { + self.added_pair = added_pair; + } + + // Getter for `special_tokens` + pub fn get_special_tokens(&self) -> &Tokens { + &self.special_tokens + } + + // Setter for `special_tokens` + pub fn set_special_tokens(&mut self, special_tokens: Tokens) { + self.special_tokens = special_tokens; + } +} + impl From<&str> for TemplateProcessingBuilderError { fn from(e: &str) -> Self { e.to_string().into() diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index d15093fb..432c6cc6 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -87,6 +87,12 @@ pub enum SplitDelimiterBehavior { Contiguous, } +impl std::fmt::Display for SplitDelimiterBehavior { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.serialize(f) + } +} + /// A `NormalizedString` takes care of processing an "original" string to modify /// it and obtain a "normalized" string. It keeps both version of the string, /// alignments information between both and provides an interface to retrieve
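Finally, a hedged sketch of the `single` accessor exposed through the new TemplateProcessing getters/setters above. That the setter accepts a plain template string (as the constructor does) is an assumption, and the getter returns a debug-formatted string rather than the original template:

    from tokenizers.processors import TemplateProcessing

    tp = TemplateProcessing(
        single="[CLS] $0 [SEP]",
        special_tokens=[("[CLS]", 1), ("[SEP]", 2)],
    )
    tp.single = "$0 [SEP]"    # assumed to re-parse the template like the constructor argument
    print(tp.single)          # debug representation of the updated template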