diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs
index fb0edb71..1b12c9b9 100644
--- a/bindings/python/src/pre_tokenizers.rs
+++ b/bindings/python/src/pre_tokenizers.rs
@@ -19,17 +19,18 @@ use tk::{PreTokenizedString, PreTokenizer};
 use tokenizers as tk;
 
 use super::error::ToPyResult;
+use super::utils::*;
 
 #[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
 #[derive(Clone, Serialize, Deserialize)]
 pub struct PyPreTokenizer {
     #[serde(flatten)]
-    pub(crate) pretok: PyPreTokenizerWrapper,
+    pub(crate) pretok: PyPreTokenizerTypeWrapper,
 }
 
 impl PyPreTokenizer {
     #[allow(dead_code)]
-    pub(crate) fn new(pretok: PyPreTokenizerWrapper) -> Self {
+    pub(crate) fn new(pretok: PyPreTokenizerTypeWrapper) -> Self {
         PyPreTokenizer { pretok }
     }
 
@@ -38,32 +39,38 @@ impl PyPreTokenizer {
         let gil = Python::acquire_gil();
         let py = gil.python();
         Ok(match &self.pretok {
-            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
-            PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
-            PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
-                PreTokenizerWrapper::Whitespace(_) => {
-                    Py::new(py, (PyWhitespace {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Punctuation(_) => {
-                    Py::new(py, (PyPunctuation {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
-                PreTokenizerWrapper::Metaspace(_) => {
-                    Py::new(py, (PyMetaspace {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Delimiter(_) => {
-                    Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::WhitespaceSplit(_) => {
-                    Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::ByteLevel(_) => {
-                    Py::new(py, (PyByteLevel {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::BertPreTokenizer(_) => {
-                    Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
+            PyPreTokenizerTypeWrapper::Sequence(_) => {
+                Py::new(py, (PySequence {}, base))?.into_py(py)
+            }
+            PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
+                PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
+                PyPreTokenizerWrapper::Wrapped(inner) => match inner {
+                    PreTokenizerWrapper::Whitespace(_) => {
+                        Py::new(py, (PyWhitespace {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Punctuation(_) => {
+                        Py::new(py, (PyPunctuation {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Sequence(_) => {
+                        Py::new(py, (PySequence {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Metaspace(_) => {
+                        Py::new(py, (PyMetaspace {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Delimiter(_) => {
+                        Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::WhitespaceSplit(_) => {
+                        Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::ByteLevel(_) => {
+                        Py::new(py, (PyByteLevel {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::BertPreTokenizer(_) => {
+                        Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
+                },
             },
         })
     }
@@ -77,13 +84,12 @@ impl PreTokenizer for PyPreTokenizer {
 
 #[pymethods]
 impl PyPreTokenizer {
-    // #[staticmethod]
-    // fn custom(pretok: PyObject) -> PyResult<Self> {
-    //     let py_pretok = CustomPreTokenizer::new(pretok)?;
-    //     Ok(PyPreTokenizer {
-    //         pretok: Arc::new(py_pretok),
-    //     })
-    // }
+    #[staticmethod]
+    fn custom(pretok: PyObject) -> PyResult<Self> {
+        Ok(PyPreTokenizer {
+            pretok: PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(pretok)).into(),
+        })
+    }
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.pretok).map_err(|e| {
@@ -227,16 +233,15 @@ impl PySequence {
         for n in pre_tokenizers.iter() {
             let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
             match &pretokenizer.pretok {
-                PyPreTokenizerWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
-                PyPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
-                PyPreTokenizerWrapper::Custom(_) => unreachable!(
-                    "Custom pretokenizers are currently disabled, how did you get here?"
-                ),
+                PyPreTokenizerTypeWrapper::Sequence(inner) => {
+                    sequence.extend(inner.iter().cloned())
+                }
+                PyPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
         Ok((
             PySequence {},
-            PyPreTokenizer::new(PyPreTokenizerWrapper::Sequence(sequence)),
+            PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Sequence(sequence)),
         ))
     }
 
@@ -289,54 +294,35 @@ impl PyDigits {
     }
 }
 
-// this is not accessible in python since the custom method is disabled.
-#[allow(dead_code)]
+#[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
-    class: PyObject,
+    inner: PyObject,
 }
 
 impl CustomPreTokenizer {
-    #[allow(dead_code)]
-    pub fn new(class: PyObject) -> PyResult<Self> {
-        Ok(CustomPreTokenizer { class })
+    pub fn new(inner: PyObject) -> Self {
+        Self { inner }
+    }
+}
+
+impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
+    fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
+        Python::with_gil(|py| {
+            let pretok = PyPreTokenizedStringRefMut::new(sentence);
+            let py_pretok = self.inner.as_ref(py);
+            py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?;
+            Ok(())
+        })
     }
 }
 
-// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
-//     fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
-//         let gil = Python::acquire_gil();
-//         let py = gil.python();
-//
-//         let args = PyTuple::new(py, &[sentence.get()]);
-//         match self.class.call_method(py, "pre_tokenize", args, None) {
-//             Ok(res) => Ok(res
-//                 .cast_as::<PyList>(py)
-//                 .map_err(|_| {
-//                     PyError::from("`pre_tokenize is expected to return a List[(str, (uint, uint))]")
-//                 })?
-//                 .extract::<Vec<(String, Offsets)>>()
-//                 .map_err(|_| {
-//                     PyError::from(
-//                         "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
-//                     )
-//                 })?),
-//             Err(e) => {
-//                 e.print(py);
-//                 Err(Box::new(PyError::from(
-//                     "Error while calling `pre_tokenize`",
-//                 )))
-//             }
-//         }
-//     }
-// }
-//
 impl Serialize for CustomPreTokenizer {
     fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
     where
         S: Serializer,
     {
         Err(serde::ser::Error::custom(
-            "Custom PyPreTokenizer cannot be serialized",
+            "Custom PreTokenizer cannot be serialized",
         ))
     }
 }
@@ -346,16 +332,17 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer {
     where
         D: Deserializer<'de>,
     {
-        Err(serde::de::Error::custom("PyDecoder cannot be deserialized"))
+        Err(serde::de::Error::custom(
+            "Custom PreTokenizer cannot be deserialized",
+        ))
     }
 }
 
 #[derive(Clone, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerWrapper {
-    Sequence(Vec<Arc<PreTokenizerWrapper>>),
-    Custom(Arc<CustomPreTokenizer>),
-    Wrapped(Arc<PreTokenizerWrapper>),
+    Custom(CustomPreTokenizer),
+    Wrapped(PreTokenizerWrapper),
 }
 
 impl Serialize for PyPreTokenizerWrapper {
@@ -364,14 +351,32 @@ impl Serialize for PyPreTokenizerWrapper {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
         S: Serializer,
     {
         match self {
-            PyPreTokenizerWrapper::Sequence(seq) => {
+            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
+            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(untagged)]
+pub(crate) enum PyPreTokenizerTypeWrapper {
+    Sequence(Vec<Arc<PyPreTokenizerWrapper>>),
+    Single(Arc<PyPreTokenizerWrapper>),
+}
+
+impl Serialize for PyPreTokenizerTypeWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            PyPreTokenizerTypeWrapper::Sequence(seq) => {
                 let mut ser = serializer.serialize_struct("Sequence", 2)?;
                 ser.serialize_field("type", "Sequence")?;
                 ser.serialize_field("pretokenizers", seq)?;
                 ser.end()
             }
-            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
-            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
         }
     }
 }
@@ -380,8 +385,17 @@ impl<I> From<I> for PyPreTokenizerWrapper
 where
     I: Into<PreTokenizerWrapper>,
 {
-    fn from(norm: I) -> Self {
-        PyPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    fn from(pretok: I) -> Self {
+        PyPreTokenizerWrapper::Wrapped(pretok.into())
+    }
+}
+
+impl<I> From<I> for PyPreTokenizerTypeWrapper
+where
+    I: Into<PyPreTokenizerWrapper>,
+{
+    fn from(pretok: I) -> Self {
+        PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into()))
     }
 }
 
@@ -396,28 +410,36 @@ where
     }
 }
 
-impl PreTokenizer for PyPreTokenizerWrapper {
-    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
+impl PreTokenizer for PyPreTokenizerTypeWrapper {
+    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
         match self {
-            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(normalized),
-            PyPreTokenizerWrapper::Sequence(inner) => {
-                inner.iter().map(|n| n.pre_tokenize(normalized)).collect()
-            }
-            PyPreTokenizerWrapper::Custom(_) => {
-                unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok),
+            PyPreTokenizerTypeWrapper::Sequence(inner) => {
+                inner.iter().map(|n| n.pre_tokenize(pretok)).collect()
             }
         }
     }
 }
 
+impl PreTokenizer for PyPreTokenizerWrapper {
+    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
+        match self {
+            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
+            PyPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use pyo3::prelude::*;
 
-    use tk::pre_tokenizers::whitespace::Whitespace;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
     use tk::pre_tokenizers::PreTokenizerWrapper;
 
-    use crate::pre_tokenizers::{CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerWrapper};
-    use std::sync::Arc;
+    use crate::pre_tokenizers::{
+        CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerTypeWrapper, PyPreTokenizerWrapper,
+    };
 
     #[test]
     fn get_subtype() {
@@ -439,18 +461,34 @@ mod test {
         assert_eq!(py_ser, rs_ser);
         let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
         match py_pretok.pretok {
-            PyPreTokenizerWrapper::Wrapped(wsp) => match wsp.as_ref() {
-                PreTokenizerWrapper::Whitespace(_) => {}
+            PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() {
+                PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
                 _ => panic!("Expected Whitespace"),
             },
            _ => panic!("Expected wrapped, not custom."),
         }
-        let gil = Python::acquire_gil();
-        let py = gil.python();
-        let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
-        let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
+
         let py_seq: PyPreTokenizerWrapper =
-            PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
+            Sequence::new(vec![Whitespace::default().into(), WhitespaceSplit.into()]).into();
+        let py_wrapper_ser = serde_json::to_string(&py_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(Sequence::new(vec![
+            Whitespace::default().into(),
+            WhitespaceSplit.into(),
+        ]));
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(py_wrapper_ser, rs_ser);
+
+        let py_seq = PyPreTokenizer::new(py_seq.into());
+        let py_ser = serde_json::to_string(&py_seq).unwrap();
+        assert_eq!(py_wrapper_ser, py_ser);
+
+        let obj = Python::with_gil(|py| {
+            let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
+            let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
+            obj
+        });
+        let py_seq: PyPreTokenizerWrapper =
+            PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(obj));
         assert!(serde_json::to_string(&py_seq).is_err());
     }
 }
diff --git a/bindings/python/tests/bindings/test_pre_tokenizers.py b/bindings/python/tests/bindings/test_pre_tokenizers.py
index 5345352c..2a9a629b 100644
--- a/bindings/python/tests/bindings/test_pre_tokenizers.py
+++ b/bindings/python/tests/bindings/test_pre_tokenizers.py
@@ -119,3 +119,26 @@ class TestDigits:
         assert isinstance(Digits(True), Digits)
         assert isinstance(Digits(False), Digits)
         assert isinstance(pickle.loads(pickle.dumps(Digits())), Digits)
+
+
+class TestCustomPreTokenizer:
+    class BadCustomPretok:
+        def pre_tokenize(self, pretok, wrong):
+            pass
+
+    class GoodCustomPretok:
+        def split(self, n, normalized):
+            return [normalized, normalized]
+
+        def pre_tokenize(self, pretok):
+            pretok.split(self.split)
+
+    def test_instantiate(self):
+        bad = PreTokenizer.custom(TestCustomPreTokenizer.BadCustomPretok())
+        good = PreTokenizer.custom(TestCustomPreTokenizer.GoodCustomPretok())
+
+        assert isinstance(bad, PreTokenizer)
+        assert isinstance(good, PreTokenizer)
+        with pytest.raises(Exception, match="TypeError: pre_tokenize()"):
+            bad.pre_tokenize("Hey there!")
+        assert good.pre_tokenize("Hey there!") == [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]
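
Note: for reference, here is a minimal usage sketch of the API this patch enables, based directly on the GoodCustomPretok fixture in the test above (the DuplicatingPretok name is illustrative, not part of the patch). Any Python object exposing pre_tokenize(self, pretok) can be wrapped with PreTokenizer.custom; the pretok handle forwards to the underlying Rust PreTokenizedString, and its split method maps a callback of the form (index, NormalizedString) -> List[NormalizedString] over every current piece:

    from tokenizers.pre_tokenizers import PreTokenizer

    class DuplicatingPretok:
        # Mirrors GoodCustomPretok from the test: every piece is duplicated.
        def split(self, i, normalized_string):
            # The callback receives (index, NormalizedString) and must
            # return a list of NormalizedString pieces.
            return [normalized_string, normalized_string]

        def pre_tokenize(self, pretok):
            # `pretok` wraps the Rust PreTokenizedString; `split` applies
            # our callback to each piece in place.
            pretok.split(self.split)

    pre_tok = PreTokenizer.custom(DuplicatingPretok())
    print(pre_tok.pre_tokenize("Hey there!"))
    # [('Hey there!', (0, 10)), ('Hey there!', (0, 10))]

Since CustomPreTokenizer deliberately errors in both its Serialize and Deserialize impls, a tokenizer built around a custom pre-tokenizer cannot be saved or pickled.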