Python - Add support for custom PreTokenizer
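This commit re-enables the previously disabled `PreTokenizer.custom` entry point in the Python bindings: any Python object exposing a `pre_tokenize(pretok)` method can now be wrapped and used as a regular `PreTokenizer`. A minimal usage sketch, assembled from the `GoodCustomPretok` test added at the end of this diff (the duplicating `split` callback is just the test's example, not a requirement):

    from tokenizers.pre_tokenizers import PreTokenizer

    class GoodCustomPretok:
        def split(self, n, normalized):
            # Called for each piece; returns the replacement pieces.
            return [normalized, normalized]

        def pre_tokenize(self, pretok):
            # pretok is a PreTokenizedString, mutated in place.
            pretok.split(self.split)

    pre_tok = PreTokenizer.custom(GoodCustomPretok())
    print(pre_tok.pre_tokenize("Hey there!"))
    # [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]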
@@ -19,17 +19,18 @@ use tk::{PreTokenizedString, PreTokenizer};
 use tokenizers as tk;
 
 use super::error::ToPyResult;
 use super::utils::*;
 
 #[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
 #[derive(Clone, Serialize, Deserialize)]
 pub struct PyPreTokenizer {
     #[serde(flatten)]
-    pub(crate) pretok: PyPreTokenizerWrapper,
+    pub(crate) pretok: PyPreTokenizerTypeWrapper,
 }
 
 impl PyPreTokenizer {
-    pub(crate) fn new(pretok: PyPreTokenizerWrapper) -> Self {
+    #[allow(dead_code)]
+    pub(crate) fn new(pretok: PyPreTokenizerTypeWrapper) -> Self {
         PyPreTokenizer { pretok }
     }
 
@@ -38,16 +39,21 @@ impl PyPreTokenizer {
         let gil = Python::acquire_gil();
         let py = gil.python();
         Ok(match &self.pretok {
+            PyPreTokenizerTypeWrapper::Sequence(_) => {
+                Py::new(py, (PySequence {}, base))?.into_py(py)
+            }
+            PyPreTokenizerTypeWrapper::Single(ref inner) => match inner.as_ref() {
                 PyPreTokenizerWrapper::Custom(_) => Py::new(py, base)?.into_py(py),
-                PyPreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
-                PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
+                PyPreTokenizerWrapper::Wrapped(inner) => match inner {
                     PreTokenizerWrapper::Whitespace(_) => {
                         Py::new(py, (PyWhitespace {}, base))?.into_py(py)
                     }
                     PreTokenizerWrapper::Punctuation(_) => {
                         Py::new(py, (PyPunctuation {}, base))?.into_py(py)
                     }
-                    PreTokenizerWrapper::Sequence(_) => Py::new(py, (PySequence {}, base))?.into_py(py),
+                    PreTokenizerWrapper::Sequence(_) => {
+                        Py::new(py, (PySequence {}, base))?.into_py(py)
+                    }
                     PreTokenizerWrapper::Metaspace(_) => {
                         Py::new(py, (PyMetaspace {}, base))?.into_py(py)
                     }
@@ -65,6 +71,7 @@ impl PyPreTokenizer {
                     }
                     PreTokenizerWrapper::Digits(_) => Py::new(py, (PyDigits {}, base))?.into_py(py),
                 },
+            },
         })
     }
 }
@@ -77,13 +84,12 @@ impl PreTokenizer for PyPreTokenizer {
 
 #[pymethods]
 impl PyPreTokenizer {
-    // #[staticmethod]
-    // fn custom(pretok: PyObject) -> PyResult<Self> {
-    //     let py_pretok = CustomPreTokenizer::new(pretok)?;
-    //     Ok(PyPreTokenizer {
-    //         pretok: Arc::new(py_pretok),
-    //     })
-    // }
+    #[staticmethod]
+    fn custom(pretok: PyObject) -> PyResult<Self> {
+        Ok(PyPreTokenizer {
+            pretok: PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(pretok)).into(),
+        })
+    }
 
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.pretok).map_err(|e| {
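Because `__getstate__` serializes `self.pretok` through serde and the `Serialize` impl for `CustomPreTokenizer` (further down) is a hard error, a custom pre-tokenizer deliberately refuses to pickle. A quick illustration of that behavior (a sketch reusing `GoodCustomPretok` from the tests below):

    import pickle
    from tokenizers.pre_tokenizers import PreTokenizer

    custom = PreTokenizer.custom(GoodCustomPretok())
    try:
        pickle.dumps(custom)
    except Exception as e:
        print(e)  # the error mentions "Custom PreTokenizer cannot be serialized"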
@@ -227,16 +233,15 @@ impl PySequence {
         for n in pre_tokenizers.iter() {
             let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
             match &pretokenizer.pretok {
-                PyPreTokenizerWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
-                PyPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
-                PyPreTokenizerWrapper::Custom(_) => unreachable!(
-                    "Custom pretokenizers are currently disabled, how did you get here?"
-                ),
+                PyPreTokenizerTypeWrapper::Sequence(inner) => {
+                    sequence.extend(inner.iter().cloned())
+                }
+                PyPreTokenizerTypeWrapper::Single(inner) => sequence.push(inner.clone()),
             }
         }
         Ok((
             PySequence {},
-            PyPreTokenizer::new(PyPreTokenizerWrapper::Sequence(sequence)),
+            PyPreTokenizer::new(PyPreTokenizerTypeWrapper::Sequence(sequence)),
         ))
     }
 
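Note the `Sequence` arm above: when a `PySequence` is built from a Python list, an element that is itself a sequence has its children spliced into the new sequence rather than being nested. An illustrative sketch (using the classes exported by `tokenizers.pre_tokenizers`):

    from tokenizers.pre_tokenizers import Punctuation, Sequence, Whitespace

    inner = Sequence([Whitespace()])
    # inner's children are extend()-ed into outer, so this behaves like
    # Sequence([Whitespace(), Punctuation()]) instead of a nested sequence.
    outer = Sequence([inner, Punctuation()])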
@@ -289,54 +294,35 @@ impl PyDigits {
     }
 }
 
-// this is not accessible in python since the custom method is disabled.
-#[allow(dead_code)]
 #[derive(Clone)]
 pub(crate) struct CustomPreTokenizer {
-    class: PyObject,
+    inner: PyObject,
 }
 
 impl CustomPreTokenizer {
-    #[allow(dead_code)]
-    pub fn new(class: PyObject) -> PyResult<Self> {
-        Ok(CustomPreTokenizer { class })
+    pub fn new(inner: PyObject) -> Self {
+        Self { inner }
     }
 }
 
+impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
+    fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
+        Python::with_gil(|py| {
+            let pretok = PyPreTokenizedStringRefMut::new(sentence);
+            let py_pretok = self.inner.as_ref(py);
+            py_pretok.call_method("pre_tokenize", (pretok.get(),), None)?;
+            Ok(())
+        })
+    }
+}
+
-// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
-//     fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
-//         let gil = Python::acquire_gil();
-//         let py = gil.python();
-//
-//         let args = PyTuple::new(py, &[sentence.get()]);
-//         match self.class.call_method(py, "pre_tokenize", args, None) {
-//             Ok(res) => Ok(res
-//                 .cast_as::<PyList>(py)
-//                 .map_err(|_| {
-//                     PyError::from("`pre_tokenize is expected to return a List[(str, (uint, uint))]")
-//                 })?
-//                 .extract::<Vec<(String, Offsets)>>()
-//                 .map_err(|_| {
-//                     PyError::from(
-//                         "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
-//                     )
-//                 })?),
-//             Err(e) => {
-//                 e.print(py);
-//                 Err(Box::new(PyError::from(
-//                     "Error while calling `pre_tokenize`",
-//                 )))
-//             }
-//         }
-//     }
-// }
-//
 impl Serialize for CustomPreTokenizer {
     fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
     where
         S: Serializer,
     {
         Err(serde::ser::Error::custom(
-            "Custom PyPreTokenizer cannot be serialized",
+            "Custom PreTokenizer cannot be serialized",
         ))
     }
 }
@@ -346,16 +332,17 @@ impl<'de> Deserialize<'de> for CustomPreTokenizer
     where
         D: Deserializer<'de>,
     {
-        Err(serde::de::Error::custom("PyDecoder cannot be deserialized"))
+        Err(serde::de::Error::custom(
+            "Custom PreTokenizer cannot be deserialized",
+        ))
     }
 }
 
 #[derive(Clone, Deserialize)]
 #[serde(untagged)]
 pub(crate) enum PyPreTokenizerWrapper {
-    Sequence(Vec<Arc<PreTokenizerWrapper>>),
-    Custom(Arc<CustomPreTokenizer>),
-    Wrapped(Arc<PreTokenizerWrapper>),
+    Custom(CustomPreTokenizer),
+    Wrapped(PreTokenizerWrapper),
 }
 
 impl Serialize for PyPreTokenizerWrapper {
@@ -364,14 +351,32 @@ impl Serialize for PyPreTokenizerWrapper {
         S: Serializer,
     {
         match self {
-            PyPreTokenizerWrapper::Sequence(seq) => {
+            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
+            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(untagged)]
+pub(crate) enum PyPreTokenizerTypeWrapper {
+    Sequence(Vec<Arc<PyPreTokenizerWrapper>>),
+    Single(Arc<PyPreTokenizerWrapper>),
+}
+
+impl Serialize for PyPreTokenizerTypeWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            PyPreTokenizerTypeWrapper::Sequence(seq) => {
                 let mut ser = serializer.serialize_struct("Sequence", 2)?;
                 ser.serialize_field("type", "Sequence")?;
                 ser.serialize_field("pretokenizers", seq)?;
                 ser.end()
             }
-            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
-            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.serialize(serializer),
         }
     }
 }
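The `serialize_struct` above keeps the on-disk format identical to the Rust library's: a sequence is written with a `type` tag and a `pretokenizers` list, and `Single` flattens to the wrapped pre-tokenizer's own serialization, which is exactly what the `get_subtype` test below compares against `tk::pre_tokenizers::PreTokenizerWrapper`. In Python terms (a sketch; the exact inner fields depend on each built-in's serde output):

    import pickle
    from tokenizers.pre_tokenizers import Sequence, Whitespace, WhitespaceSplit

    seq = Sequence([Whitespace(), WhitespaceSplit()])
    # Built-ins serialize through serde, so pickling round-trips; the payload
    # is shaped like {"type": "Sequence", "pretokenizers": [{"type": ...}, ...]}.
    assert isinstance(pickle.loads(pickle.dumps(seq)), Sequence)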
@@ -380,8 +385,17 @@ impl<I> From<I> for PyPreTokenizerWrapper
 where
     I: Into<PreTokenizerWrapper>,
 {
-    fn from(norm: I) -> Self {
-        PyPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    fn from(pretok: I) -> Self {
+        PyPreTokenizerWrapper::Wrapped(pretok.into())
     }
 }
 
+impl<I> From<I> for PyPreTokenizerTypeWrapper
+where
+    I: Into<PyPreTokenizerWrapper>,
+{
+    fn from(pretok: I) -> Self {
+        PyPreTokenizerTypeWrapper::Single(Arc::new(pretok.into()))
+    }
+}
+
@@ -396,28 +410,36 @@ where
     }
 }
 
-impl PreTokenizer for PyPreTokenizerWrapper {
-    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
+impl PreTokenizer for PyPreTokenizerTypeWrapper {
+    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
         match self {
-            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(normalized),
-            PyPreTokenizerWrapper::Sequence(inner) => {
-                inner.iter().map(|n| n.pre_tokenize(normalized)).collect()
+            PyPreTokenizerTypeWrapper::Single(inner) => inner.pre_tokenize(pretok),
+            PyPreTokenizerTypeWrapper::Sequence(inner) => {
+                inner.iter().map(|n| n.pre_tokenize(pretok)).collect()
             }
-            PyPreTokenizerWrapper::Custom(_) => {
-                unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
-            }
         }
     }
 }
+
+impl PreTokenizer for PyPreTokenizerWrapper {
+    fn pre_tokenize(&self, pretok: &mut PreTokenizedString) -> tk::Result<()> {
+        match self {
+            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(pretok),
+            PyPreTokenizerWrapper::Custom(inner) => inner.pre_tokenize(pretok),
+        }
+    }
+}
 
 #[cfg(test)]
 mod test {
     use pyo3::prelude::*;
-    use tk::pre_tokenizers::whitespace::Whitespace;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
     use tk::pre_tokenizers::PreTokenizerWrapper;
 
-    use crate::pre_tokenizers::{CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerWrapper};
     use std::sync::Arc;
+    use crate::pre_tokenizers::{
+        CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerTypeWrapper, PyPreTokenizerWrapper,
+    };
 
     #[test]
     fn get_subtype() {
@@ -439,18 +461,34 @@ mod test {
         assert_eq!(py_ser, rs_ser);
         let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
         match py_pretok.pretok {
-            PyPreTokenizerWrapper::Wrapped(wsp) => match wsp.as_ref() {
-                PreTokenizerWrapper::Whitespace(_) => {}
+            PyPreTokenizerTypeWrapper::Single(inner) => match inner.as_ref() {
+                PyPreTokenizerWrapper::Wrapped(PreTokenizerWrapper::Whitespace(_)) => {}
                 _ => panic!("Expected Whitespace"),
             },
             _ => panic!("Expected wrapped, not custom."),
         }
-        let gil = Python::acquire_gil();
-        let py = gil.python();
+
+        let py_seq: PyPreTokenizerWrapper =
+            Sequence::new(vec![Whitespace::default().into(), WhitespaceSplit.into()]).into();
+        let py_wrapper_ser = serde_json::to_string(&py_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(Sequence::new(vec![
+            Whitespace::default().into(),
+            WhitespaceSplit.into(),
+        ]));
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(py_wrapper_ser, rs_ser);
+
+        let py_seq = PyPreTokenizer::new(py_seq.into());
+        let py_ser = serde_json::to_string(&py_seq).unwrap();
+        assert_eq!(py_wrapper_ser, py_ser);
+
+        let obj = Python::with_gil(|py| {
             let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
             let obj: PyObject = Py::new(py, py_wsp).unwrap().into_py(py);
+            obj
+        });
         let py_seq: PyPreTokenizerWrapper =
-            PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
+            PyPreTokenizerWrapper::Custom(CustomPreTokenizer::new(obj));
         assert!(serde_json::to_string(&py_seq).is_err());
     }
 }
 
@@ -119,3 +119,26 @@ class TestDigits:
         assert isinstance(Digits(True), Digits)
         assert isinstance(Digits(False), Digits)
         assert isinstance(pickle.loads(pickle.dumps(Digits())), Digits)
+
+
+class TestCustomPreTokenizer:
+    class BadCustomPretok:
+        def pre_tokenize(self, pretok, wrong):
+            pass
+
+    class GoodCustomPretok:
+        def split(self, n, normalized):
+            return [normalized, normalized]
+
+        def pre_tokenize(self, pretok):
+            pretok.split(self.split)
+
+    def test_instantiate(self):
+        bad = PreTokenizer.custom(TestCustomPreTokenizer.BadCustomPretok())
+        good = PreTokenizer.custom(TestCustomPreTokenizer.GoodCustomPretok())
+
+        assert isinstance(bad, PreTokenizer)
+        assert isinstance(good, PreTokenizer)
+        with pytest.raises(Exception, match="TypeError: pre_tokenize()"):
+            bad.pre_tokenize("Hey there!")
+        assert good.pre_tokenize("Hey there!") == [("Hey there!", (0, 10)), ("Hey there!", (0, 10))]
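As with the built-ins, the wrapped object can then be attached to a tokenizer, with the caveat noted above that such a tokenizer can no longer be serialized or saved. A final hedged sketch (`tokenizer` stands for any existing `tokenizers.Tokenizer` instance):

    tokenizer.pre_tokenizer = PreTokenizer.custom(GoodCustomPretok())
    encoding = tokenizer.encode("Hey there!")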