Python - Add support for custom PreTokenizer

Author: Anthony MOI
Date: 2020-09-18 15:56:02 -04:00
Committed by: Anthony MOI
Parent: 8d04b22278
Commit: bd8f25ee2c
2 changed files with 164 additions and 103 deletions
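
In short: this commit re-enables `PreTokenizer.custom` on the Python side. A plain Python object exposing a `pre_tokenize(pretok)` method is wrapped in the new `CustomPreTokenizer` and called back from Rust under the GIL. A minimal usage sketch, mirroring the tests added at the end of this diff (the class and variable names are illustrative):

    from tokenizers.pre_tokenizers import PreTokenizer

    class DuplicatingPretok:
        def pre_tokenize(self, pretok):
            # pretok is a PreTokenizedString; split() takes a callable
            # (index, NormalizedString) -> list of NormalizedString pieces
            pretok.split(lambda n, normalized: [normalized, normalized])

    dup = PreTokenizer.custom(DuplicatingPretok())
    print(dup.pre_tokenize("Hey there!"))
    # [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]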


@@ -19,17 +19,18 @@ use tk::{PreTokenizedString, PreTokenizer};
use tokenizers as tk;

use super::error::ToPyResult;
use super::utils::*;

#[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
#[derive(Clone, Serialize, Deserialize)]
pub struct PyPreTokenizer {
    #[serde(flatten)]
    pub(crate) pretok: PyPreTokenizerWrapper,
    pub(crate) pretok: PyPreTokenizerTypeWrapper,
}

impl PyPreTokenizer {
    #[allow(dead_code)]
    pub(crate) fn new(pretok: PyPreTokenizerWrapper) -> Self {
    pub(crate) fn new(pretok: PyPreTokenizerTypeWrapper) -> Self {
        PyPreTokenizer { pretok }
    }
@@ -38,32 +39,38 @@ impl PyPreTokenizer {
        let gil = Python::acquire_gil();
        let py = gil.python();
        Ok(match &self.pretok {
            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
            PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
            PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
                PreTokenizerWrapper::Whitespace(_) => {
                    Py::new(py, (PyWhitespace {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::Punctuation(_) => {
                    Py::new(py, (PyPunctuation {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
                PreTokenizerWrapper::Metaspace(_) => {
                    Py::new(py, (PyMetaspace {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::Delimiter(_) => {
                    Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::WhitespaceSplit(_) => {
                    Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::ByteLevel(_) => {
                    Py::new(py, (PyByteLevel {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::BertPreTokenizer(_) => {
                    Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
                }
                PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
            PyPreTokenizerTypeWrapper::Sequence(_) => {
                Py::new(py, (PySequence {}, base))?.into_py(py)
            }
            PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
                PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
                PyPreTokenizerWrapper::Wrapped(inner) => match inner {
                    PreTokenizerWrapper::Whitespace(_) => {
                        Py::new(py, (PyWhitespace {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::Punctuation(_) => {
                        Py::new(py, (PyPunctuation {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::Sequence(_) => {
                        Py::new(py, (PySequence {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::Metaspace(_) => {
                        Py::new(py, (PyMetaspace {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::Delimiter(_) => {
                        Py::new(py, (PyCharDelimiterSplit {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::WhitespaceSplit(_) => {
                        Py::new(py, (PyWhitespaceSplit {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::ByteLevel(_) => {
                        Py::new(py, (PyByteLevel {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::BertPreTokenizer(_) => {
                        Py::new(py, (PyBertPreTokenizer {}, base))?.into_py(py)
                    }
                    PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
                },
            },
        })
    }
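
This mapping is what lets a generic `PreTokenizer` surface in Python as the concrete subclass after deserialization or unpickling, e.g. (sketch, following the pickle round-trips used elsewhere in the test file):

    import pickle
    from tokenizers.pre_tokenizers import Whitespace

    ws = pickle.loads(pickle.dumps(Whitespace()))
    assert isinstance(ws, Whitespace)  # the concrete subclass, not the bare base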
@@ -77,13 +84,12 @@ impl PreTokenizer for PyPreTokenizer {
#[pymethods]
impl PyPreTokenizer {
    // #[staticmethod]
    // fn custom(pretok: PyObject) -> PyResult<Self> {
    //     let py_pretok = CustomPreTokenizer::new(pretok)?;
    //     Ok(PyPreTokenizer {
    //         pretok: Arc::new(py_pretok),
    //     })
    // }
    #[staticmethod]
    fn custom(pretok: PyObject) -> PyResult<Self> {
        Ok(PyPreTokenizer {
            pretok: PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(pretok)).into(),
        })
    }

    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.pretok).map_err(|e| {
@@ -227,16 +233,15 @@ impl PySequence {
        for n in pre_tokenizers.iter() {
            let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
            match &pretokenizer.pretok {
                PyPreTokenizerWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
                PyPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
                PyPreTokenizerWrapper::Custom(_) => unreachable!(
                    "Custom pretokenizers are currently disabled, how did you get here?"
                ),
                PyPreTokenizerTypeWrapper::Sequence(inner) => {
                    sequence.extend(inner.iter().cloned())
                }
                PyPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
            }
        }

        Ok((
            PySequence {},
            PyPreTokenizer::new(PyPreTokenizerWrapper::Sequence(sequence)),
            PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Sequence(sequence)),
        ))
    }
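
Note the flattening: building a `PySequence` from Python splices the contents of any nested sequence directly into the new list instead of nesting it. From the Python side (sketch; `Sequence` takes a list of pre-tokenizers):

    from tokenizers.pre_tokenizers import Digits, Sequence, Whitespace

    seq = Sequence([Whitespace(), Sequence([Digits(True)])])
    # stored internally as the flat list [Whitespace, Digits]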
@@ -289,54 +294,35 @@ impl PyDigits {
    }
}

// this is not accessible in python since the custom method is disabled.
#[allow(dead_code)]
#[derive(Clone)]
pub(crate) struct CustomPreTokenizer {
    class: PyObject,
    inner: PyObject,
}

impl CustomPreTokenizer {
    #[allow(dead_code)]
    pub fn new(class: PyObject) -> PyResult<Self> {
        Ok(CustomPreTokenizer { class })
    pub fn new(inner: PyObject) -> Self {
        Self { inner }
    }
}

impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
    fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
        Python::with_gil(|py| {
            let pretok = PyPreTokenizedStringRefMut::new(sentence);
            let py_pretok = self.inner.as_ref(py);
            py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?;
            Ok(())
        })
    }
}

// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
//     fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
//         let gil = Python::acquire_gil();
//         let py = gil.python();
//
//         let args = PyTuple::new(py, &[sentence.get()]);
//         match self.class.call_method(py, "pre_tokenize", args, None) {
//             Ok(res) => Ok(res
//                 .cast_as::<PyList>(py)
//                 .map_err(|_| {
//                     PyError::from("`pre_tokenize` is expected to return a List[(str, (uint, uint))]")
//                 })?
//                 .extract::<Vec<(String, Offsets)>>()
//                 .map_err(|_| {
//                     PyError::from(
//                         "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
//                     )
//                 })?),
//             Err(e) => {
//                 e.print(py);
//                 Err(Box::new(PyError::from(
//                     "Error while calling `pre_tokenize`",
//                 )))
//             }
//         }
//     }
// }
//

impl Serialize for CustomPreTokenizer {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Err(serde::ser::Error::custom(
            "Custom PyPreTokenizer cannot be serialized",
            "Custom PreTokenizer cannot be serialized",
        ))
    }
}
@@ -346,16 +332,17 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer {
    where
        D: Deserializer<'de>,
    {
        Err(serde::de::Error::custom("PyDecoder cannot be deserialized"))
        Err(serde::de::Error::custom(
            "Custom PreTokenizer cannot be deserialized",
        ))
    }
}

#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerWrapper {
    Sequence(Vec<Arc<PreTokenizerWrapper>>),
    Custom(Arc<CustomPreTokenizer>),
    Wrapped(Arc<PreTokenizerWrapper>),
    Custom(CustomPreTokenizer),
    Wrapped(PreTokenizerWrapper),
}

impl Serialize for PyPreTokenizerWrapper {
@@ -364,14 +351,32 @@ impl Serialize for PyPreTokenizerWrapper {
        S: Serializer,
    {
        match self {
            PyPreTokenizerWrapper::Sequence(seq) => {
            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
        }
    }
}

#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerTypeWrapper {
    Sequence(Vec<Arc<PyPreTokenizerWrapper>>),
    Single(Arc<PyPreTokenizerWrapper>),
}

impl Serialize for PyPreTokenizerTypeWrapper {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            PyPreTokenizerTypeWrapper::Sequence(seq) => {
                let mut ser = serializer.serialize_struct("Sequence", 2)?;
                ser.serialize_field("type", "Sequence")?;
                ser.serialize_field("pretokenizers", seq)?;
                ser.end()
            }
            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
            PyPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
        }
    }
}
@@ -380,8 +385,17 @@ impl<I> From<I> for PyPreTokenizerWrapper
where
    I: Into<PreTokenizerWrapper>,
{
    fn from(norm: I) -> Self {
        PyPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
    fn from(pretok: I) -> Self {
        PyPreTokenizerWrapper::Wrapped(pretok.into())
    }
}

impl<I> From<I> for PyPreTokenizerTypeWrapper
where
    I: Into<PyPreTokenizerWrapper>,
{
    fn from(pretok: I) -> Self {
        PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into()))
    }
}
@@ -396,28 +410,36 @@ where
    }
}

impl PreTokenizer for PyPreTokenizerWrapper {
    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
impl PreTokenizer for PyPreTokenizerTypeWrapper {
    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
        match self {
            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(normalized),
            PyPreTokenizerWrapper::Sequence(inner) => {
                inner.iter().map(|n| n.pre_tokenize(normalized)).collect()
            }
            PyPreTokenizerWrapper::Custom(_) => {
                unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
            PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok),
            PyPreTokenizerTypeWrapper::Sequence(inner) => {
                inner.iter().map(|n| n.pre_tokenize(pretok)).collect()
            }
        }
    }
}

impl PreTokenizer for PyPreTokenizerWrapper {
    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
        match self {
            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
            PyPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
        }
    }
}

#[cfg(test)]
mod test {
    use pyo3::prelude::*;
    use tk::pre_tokenizers::whitespace::Whitespace;
    use tk::pre_tokenizers::sequence::Sequence;
    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
    use tk::pre_tokenizers::PreTokenizerWrapper;

    use crate::pre_tokenizers::{CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerWrapper};
    use std::sync::Arc;

    use crate::pre_tokenizers::{
        CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerTypeWrapper, PyPreTokenizerWrapper,
    };

    #[test]
    fn get_subtype() {
@@ -439,18 +461,34 @@ mod test {
        assert_eq!(py_ser, rs_ser);

        let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
        match py_pretok.pretok {
            PyPreTokenizerWrapper::Wrapped(wsp) => match wsp.as_ref() {
                PreTokenizerWrapper::Whitespace(_) => {}
            PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() {
                PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
                _ => panic!("Expected Whitespace"),
            },
            _ => panic!("Expected wrapped, not custom."),
        }

        let gil = Python::acquire_gil();
        let py = gil.python();
        let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
        let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
        let py_seq: PyPreTokenizerWrapper =
            PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
            Sequence::new(vec![Whitespace::default().into(), WhitespaceSplit.into()]).into();
        let py_wrapper_ser = serde_json::to_string(&py_seq).unwrap();
        let rs_wrapped = PreTokenizerWrapper::Sequence(Sequence::new(vec![
            Whitespace::default().into(),
            WhitespaceSplit.into(),
        ]));
        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
        assert_eq!(py_wrapper_ser, rs_ser);

        let py_seq = PyPreTokenizer::new(py_seq.into());
        let py_ser = serde_json::to_string(&py_seq).unwrap();
        assert_eq!(py_wrapper_ser, py_ser);

        let obj = Python::with_gil(|py| {
            let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
            let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
            obj
        });

        let py_seq: PyPreTokenizerWrapper =
            PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(obj));
        assert!(serde_json::to_string(&py_seq).is_err());
    }
}
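
A consequence of the serde impls above: a custom pre-tokenizer refuses both serialization and deserialization, so pickling one from Python should fail (sketch, reusing the `DuplicatingPretok` class from the first example and assuming the serde error surfaces as a Python exception):

    import pickle
    from tokenizers.pre_tokenizers import PreTokenizer

    custom = PreTokenizer.custom(DuplicatingPretok())
    try:
        pickle.dumps(custom)  # __getstate__ runs serde_json::to_string
    except Exception as e:
        print(e)  # mentions "Custom PreTokenizer cannot be serialized"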


@@ -119,3 +119,26 @@ class TestDigits:
        assert isinstance(Digits(True), Digits)
        assert isinstance(Digits(False), Digits)
        assert isinstance(pickle.loads(pickle.dumps(Digits())), Digits)


class TestCustomPreTokenizer:
    class BadCustomPretok:
        def pre_tokenize(self, pretok, wrong):
            pass

    class GoodCustomPretok:
        def split(self, n, normalized):
            return [normalized, normalized]

        def pre_tokenize(self, pretok):
            pretok.split(self.split)

    def test_instantiate(self):
        bad = PreTokenizer.custom(TestCustomPreTokenizer.BadCustomPretok())
        good = PreTokenizer.custom(TestCustomPreTokenizer.GoodCustomPretok())
        assert isinstance(bad, PreTokenizer)
        assert isinstance(good, PreTokenizer)

        with pytest.raises(Exception, match="TypeError: pre_tokenize()"):
            bad.pre_tokenize("Hey there!")
        assert good.pre_tokenize("Hey there!") == [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]
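
To close the loop, a custom pre-tokenizer plugs into a full pipeline like any built-in one (sketch; `Tokenizer` and `BPE` are standard tokenizers API, not part of this diff):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import PreTokenizer

    tokenizer = Tokenizer(BPE())
    tokenizer.pre_tokenizer = PreTokenizer.custom(TestCustomPreTokenizer.GoodCustomPretok())
    # Encoding works as usual, but saving this tokenizer will fail while a
    # non-serializable custom pre-tokenizer is attached (see the serde impls above).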