Python - Add support for custom PreTokenizer

Anthony MOI
2020-09-18 15:56:02 -04:00
committed by Anthony MOI
parent 8d04b22278
commit bd8f25ee2c
2 changed files with 164 additions and 103 deletions
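As context for the diff below, a minimal usage sketch of the API this commit enables, adapted from the Python test added at the bottom (the class name MySplitter is illustrative, not part of the commit):

    from tokenizers.pre_tokenizers import PreTokenizer

    class MySplitter:
        def split(self, i, normalized_string):
            # callback handed to PreTokenizedString.split(); here it just
            # duplicates the input, mirroring GoodCustomPretok in the tests
            return [normalized_string, normalized_string]

        def pre_tokenize(self, pretok):
            pretok.split(self.split)

    custom = PreTokenizer.custom(MySplitter())
    custom.pre_tokenize("Hey there!")
    # -> [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]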

View File

@@ -19,17 +19,18 @@ use tk::{PreTokenizedString, PreTokenizer};
 use tokenizers as tk;
 
 use super::error::ToPyResult;
+use super::utils::*;
 
 #[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
 #[derive(Clone, Serialize, Deserialize)]
 pub struct PyPreTokenizer {
     #[serde(flatten)]
-    pub(crate) pretok: PyPreTokenizerWrapper,
+    pub(crate) pretok: PyPreTokenizerTypeWrapper,
 }
 
 impl PyPreTokenizer {
     #[allow(dead_code)]
-    pub(crate) fn new(pretok: PyPreTokenizerWrapper) -> Self {
+    pub(crate) fn new(pretok: PyPreTokenizerTypeWrapper) -> Self {
         PyPreTokenizer { pretok }
     }
@@ -38,32 +39,38 @@ impl PyPreTokenizer {
         let gil = Python::acquire_gil();
         let py = gil.python();
         Ok(match &self.pretok {
-            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
-            PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
-            PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
-                PreTokenizerWrapper::Whitespace(_) => {
-                    Py::new(py, (PyWhitespace {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Punctuation(_) => {
-                    Py::new(py, (PyPunctuation {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
-                PreTokenizerWrapper::Metaspace(_) => {
-                    Py::new(py, (PyMetaspace {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Delimiter(_) => {
-                    Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::WhitespaceSplit(_) => {
-                    Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::ByteLevel(_) => {
-                    Py::new(py, (PyByteLevel {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::BertPreTokenizer(_) => {
-                    Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
-                }
-                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
+            PyPreTokenizerTypeWrapper::Sequence(_) => {
+                Py::new(py, (PySequence {}, base))?.into_py(py)
+            }
+            PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
+                PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
+                PyPreTokenizerWrapper::Wrapped(inner) => match inner {
+                    PreTokenizerWrapper::Whitespace(_) => {
+                        Py::new(py, (PyWhitespace {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Punctuation(_) => {
+                        Py::new(py, (PyPunctuation {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Sequence(_) => {
+                        Py::new(py, (PySequence {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Metaspace(_) => {
+                        Py::new(py, (PyMetaspace {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Delimiter(_) => {
+                        Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::WhitespaceSplit(_) => {
+                        Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::ByteLevel(_) => {
+                        Py::new(py, (PyByteLevel {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::BertPreTokenizer(_) => {
+                        Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
+                    }
+                    PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
+                },
             },
         })
     }
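The get_as_subtype dispatch above is what lets a deserialized pre-tokenizer surface as its concrete Python subclass; a hedged illustration, following the pickle round-trip pattern already used in the test suite:

    import pickle
    from tokenizers.pre_tokenizers import Whitespace

    ws = Whitespace()
    # comes back as Whitespace, not as a bare PreTokenizer
    assert isinstance(pickle.loads(pickle.dumps(ws)), Whitespace)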
@@ -77,13 +84,12 @@ impl PreTokenizer for PyPreTokenizer {
 #[pymethods]
 impl PyPreTokenizer {
-    // #[staticmethod]
-    // fn custom(pretok: PyObject) -> PyResult<Self> {
-    //     let py_pretok = CustomPreTokenizer::new(pretok)?;
-    //     Ok(PyPreTokenizer {
-    //         pretok: Arc::new(py_pretok),
-    //     })
-    // }
+    #[staticmethod]
+    fn custom(pretok: PyObject) -> PyResult<Self> {
+        Ok(PyPreTokenizer {
+            pretok: PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(pretok)).into(),
+        })
+    }
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.pretok).map_err(|e| {
@@ -227,16 +233,15 @@ impl PySequence {
         for n in pre_tokenizers.iter() {
             let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
             match &pretokenizer.pretok {
-                PyPreTokenizerWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
-                PyPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
-                PyPreTokenizerWrapper::Custom(_) => unreachable!(
-                    "Custom pretokenizers are currently disabled, how did you get here?"
-                ),
+                PyPreTokenizerTypeWrapper::Sequence(inner) => {
+                    sequence.extend(inner.iter().cloned())
+                }
+                PyPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
         Ok((
             PySequence {},
-            PyPreTokenizer::new(PyPreTokenizerWrapper::Sequence(sequence)),
+            PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Sequence(sequence)),
         ))
     }
@@ -289,54 +294,35 @@ impl PyDigits {
     }
 }
 
-// this is not accessible in python since the custom method is disabled.
-#[allow(dead_code)]
+#[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
-    class: PyObject,
+    inner: PyObject,
 }
 
 impl CustomPreTokenizer {
-    #[allow(dead_code)]
-    pub fn new(class: PyObject) -> PyResult<Self> {
-        Ok(CustomPreTokenizer { class })
+    pub fn new(inner: PyObject) -> Self {
+        Self { inner }
     }
 }
 
-// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
-//     fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
-//         let gil = Python::acquire_gil();
-//         let py = gil.python();
-//
-//         let args = PyTuple::new(py, &[sentence.get()]);
-//         match self.class.call_method(py, "pre_tokenize", args, None) {
-//             Ok(res) => Ok(res
-//                 .cast_as::<PyList>(py)
-//                 .map_err(|_| {
-//                     PyError::from("`pre_tokenize is expected to return a List[(str, (uint, uint))]")
-//                 })?
-//                 .extract::<Vec<(String, Offsets)>>()
-//                 .map_err(|_| {
-//                     PyError::from(
-//                         "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
-//                     )
-//                 })?),
-//             Err(e) => {
-//                 e.print(py);
-//                 Err(Box::new(PyError::from(
-//                     "Error while calling `pre_tokenize`",
-//                 )))
-//             }
-//         }
-//     }
-// }
-//
+impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
+    fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
+        Python::with_gil(|py| {
+            let pretok = PyPreTokenizedStringRefMut::new(sentence);
+            let py_pretok = self.inner.as_ref(py);
+            py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?;
+            Ok(())
+        })
+    }
+}
+
 impl Serialize for CustomPreTokenizer {
     fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
     where
         S: Serializer,
     {
         Err(serde::ser::Error::custom(
-            "Custom PyPreTokenizer cannot be serialized",
+            "Custom PreTokenizer cannot be serialized",
         ))
     }
 }
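Since the Serialize impl above always errors for custom pre-tokenizers, anything built with PreTokenizer.custom cannot be pickled; a hedged sketch of the expected failure mode (MyPretok is illustrative):

    import pickle
    from tokenizers.pre_tokenizers import PreTokenizer

    class MyPretok:
        def pre_tokenize(self, pretok):
            pass

    custom = PreTokenizer.custom(MyPretok())
    try:
        # __getstate__ serializes through serde_json, which surfaces the
        # "Custom PreTokenizer cannot be serialized" error
        pickle.dumps(custom)
    except Exception as e:
        print("not picklable:", e)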
@@ -346,16 +332,17 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer {
     where
         D: Deserializer<'de>,
     {
-        Err(serde::de::Error::custom("PyDecoder cannot be deserialized"))
+        Err(serde::de::Error::custom(
+            "Custom PreTokenizer cannot be deserialized",
+        ))
     }
 }
 
 #[derive(Clone, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerWrapper {
-    Sequence(Vec<Arc<PreTokenizerWrapper>>),
-    Custom(Arc<CustomPreTokenizer>),
-    Wrapped(Arc<PreTokenizerWrapper>),
+    Custom(CustomPreTokenizer),
+    Wrapped(PreTokenizerWrapper),
 }
 
 impl Serialize for PyPreTokenizerWrapper {
@@ -364,14 +351,32 @@ impl Serialize for PyPreTokenizerWrapper {
         S: Serializer,
     {
         match self {
-            PyPreTokenizerWrapper::Sequence(seq) => {
-                let mut ser = serializer.serialize_struct("Sequence", 2)?;
-                ser.serialize_field("type", "Sequence")?;
-                ser.serialize_field("pretokenizers", seq)?;
-                ser.end()
-            }
             PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
             PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
         }
     }
 }
 
+#[derive(Clone, Deserialize)]
+#[serde(untagged)]
+pub(crate) enum PyPreTokenizerTypeWrapper {
+    Sequence(Vec<Arc<PyPreTokenizerWrapper>>),
+    Single(Arc<PyPreTokenizerWrapper>),
+}
+
+impl Serialize for PyPreTokenizerTypeWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            PyPreTokenizerTypeWrapper::Sequence(seq) => {
+                let mut ser = serializer.serialize_struct("Sequence", 2)?;
+                ser.serialize_field("type", "Sequence")?;
+                ser.serialize_field("pretokenizers", seq)?;
+                ser.end()
+            }
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
+        }
+    }
+}
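For reference, a hedged sketch of the JSON shape the Sequence arm above emits; the inner entries are assumptions based on how unit pre-tokenizers such as WhitespaceSplit serialize elsewhere in tokenizers:

    import json

    # outer struct: {"type": "Sequence", "pretokenizers": [...]}
    expected = {
        "type": "Sequence",
        "pretokenizers": [{"type": "WhitespaceSplit"}, {"type": "WhitespaceSplit"}],
    }
    print(json.dumps(expected))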
@@ -380,8 +385,17 @@ impl<I> From<I> for PyPreTokenizerWrapper
 where
     I: Into<PreTokenizerWrapper>,
 {
-    fn from(norm: I) -> Self {
-        PyPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    fn from(pretok: I) -> Self {
+        PyPreTokenizerWrapper::Wrapped(pretok.into())
+    }
+}
+
+impl<I> From<I> for PyPreTokenizerTypeWrapper
+where
+    I: Into<PyPreTokenizerWrapper>,
+{
+    fn from(pretok: I) -> Self {
+        PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into()))
     }
 }
@@ -396,28 +410,36 @@ where
     }
 }
 
-impl PreTokenizer for PyPreTokenizerWrapper {
-    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
+impl PreTokenizer for PyPreTokenizerTypeWrapper {
+    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
         match self {
-            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(normalized),
-            PyPreTokenizerWrapper::Sequence(inner) => {
-                inner.iter().map(|n| n.pre_tokenize(normalized)).collect()
-            }
-            PyPreTokenizerWrapper::Custom(_) => {
-                unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok),
+            PyPreTokenizerTypeWrapper::Sequence(inner) => {
+                inner.iter().map(|n| n.pre_tokenize(pretok)).collect()
             }
         }
     }
 }
 
+impl PreTokenizer for PyPreTokenizerWrapper {
+    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
+        match self {
+            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
+            PyPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use pyo3::prelude::*;
-    use tk::pre_tokenizers::whitespace::Whitespace;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
     use tk::pre_tokenizers::PreTokenizerWrapper;
 
-    use crate::pre_tokenizers::{CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerWrapper};
-    use std::sync::Arc;
+    use crate::pre_tokenizers::{
+        CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerTypeWrapper, PyPreTokenizerWrapper,
+    };
 
     #[test]
     fn get_subtype() {
@@ -439,18 +461,34 @@ mod test {
         assert_eq!(py_ser, rs_ser);
         let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
         match py_pretok.pretok {
-            PyPreTokenizerWrapper::Wrapped(wsp) => match wsp.as_ref() {
-                PreTokenizerWrapper::Whitespace(_) => {}
+            PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() {
+                PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
                 _ => panic!("Expected Whitespace"),
             },
             _ => panic!("Expected wrapped, not custom."),
         }
 
-        let gil = Python::acquire_gil();
-        let py = gil.python();
-        let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
-        let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
         let py_seq: PyPreTokenizerWrapper =
-            PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
+            Sequence::new(vec![Whitespace::default().into(), WhitespaceSplit.into()]).into();
+        let py_wrapper_ser = serde_json::to_string(&py_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(Sequence::new(vec![
+            Whitespace::default().into(),
+            WhitespaceSplit.into(),
+        ]));
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(py_wrapper_ser, rs_ser);
+        let py_seq = PyPreTokenizer::new(py_seq.into());
+        let py_ser = serde_json::to_string(&py_seq).unwrap();
+        assert_eq!(py_wrapper_ser, py_ser);
+
+        let obj = Python::with_gil(|py| {
+            let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
+            let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
+            obj
+        });
+        let py_seq: PyPreTokenizerWrapper =
+            PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(obj));
         assert!(serde_json::to_string(&py_seq).is_err());
     }
 }

View File

@@ -119,3 +119,26 @@ class TestDigits:
         assert isinstance(Digits(True), Digits)
         assert isinstance(Digits(False), Digits)
         assert isinstance(pickle.loads(pickle.dumps(Digits())), Digits)
+
+
+class TestCustomPreTokenizer:
+    class BadCustomPretok:
+        def pre_tokenize(self, pretok, wrong):
+            pass
+
+    class GoodCustomPretok:
+        def split(self, n, normalized):
+            return [normalized, normalized]
+
+        def pre_tokenize(self, pretok):
+            pretok.split(self.split)
+
+    def test_instantiate(self):
+        bad = PreTokenizer.custom(TestCustomPreTokenizer.BadCustomPretok())
+        good = PreTokenizer.custom(TestCustomPreTokenizer.GoodCustomPretok())
+        assert isinstance(bad, PreTokenizer)
+        assert isinstance(good, PreTokenizer)
+
+        with pytest.raises(Exception, match="TypeError: pre_tokenize()"):
+            bad.pre_tokenize("Hey there!")
+        assert good.pre_tokenize("Hey there!") == [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]
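Finally, a hedged sketch of how a custom pre-tokenizer plugs into a full pipeline; Tokenizer, BPE and the pre_tokenizer setter follow the tokenizers Python API, while DuplicatingPretok mirrors GoodCustomPretok from the tests above:

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import PreTokenizer

    class DuplicatingPretok:
        def dup(self, i, normalized_string):
            return [normalized_string, normalized_string]

        def pre_tokenize(self, pretok):
            pretok.split(self.dup)

    tokenizer = Tokenizer(BPE())
    # custom pre-tokenizers work in a pipeline, but a tokenizer using one
    # cannot be serialized (see the Serialize impl earlier in the diff)
    tokenizer.pre_tokenizer = PreTokenizer.custom(DuplicatingPretok())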