use std::sync::Arc;

use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::pre_tokenizers::bert::BertPreTokenizer;
use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::pre_tokenizers::deduplication::Deduplication;
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
use tk::pre_tokenizers::metaspace::Metaspace;
use tk::pre_tokenizers::punctuation::Punctuation;
// use tk::pre_tokenizers::sequence::Sequence;
use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use tk::pre_tokenizers::PreTokenizerWrapper;
use tk::tokenizer::Offsets;
use tk::{PreTokenizedString, PreTokenizer};
use tokenizers as tk;

use super::error::ToPyResult;

#[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
#[derive(Clone, Serialize, Deserialize)]
pub struct PyPreTokenizer {
    #[serde(flatten)]
    pub(crate) pretok: PyPreTokenizerWrapper,
}

impl PyPreTokenizer {
    #[allow(dead_code)]
    pub(crate) fn new(pretok: PyPreTokenizerWrapper) -> Self {
        PyPreTokenizer { pretok }
    }

    /// Wraps `self` in the Python class that matches the concrete
    /// pre-tokenizer, so Python always sees the most specific subtype.
    pub(crate) fn get_as_subtype(&self) -> PyResult<PyObject> {
        let base = self.clone();
        let gil = Python::acquire_gil();
        let py = gil.python();
        match &self.pretok {
            PyPreTokenizerWrapper::Custom(_) => Py::new(py, base).map(Into::into),
            PyPreTokenizerWrapper::Sequence(_) => {
                Py::new(py, (PySequence {}, base)).map(Into::into)
            }
            PyPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
                PreTokenizerWrapper::Whitespace(_) => {
                    Py::new(py, (PyWhitespace {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::Deduplication(_) => {
                    Py::new(py, (PyDeduplication {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::Punctuation(_) => {
                    Py::new(py, (PyPunctuation {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::Sequence(_) => {
                    Py::new(py, (PySequence {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::Metaspace(_) => {
                    Py::new(py, (PyMetaspace {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::Delimiter(_) => {
                    Py::new(py, (PyCharDelimiterSplit {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::WhitespaceSplit(_) => {
                    Py::new(py, (PyWhitespaceSplit {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::ByteLevel(_) => {
                    Py::new(py, (PyByteLevel {}, base)).map(Into::into)
                }
                PreTokenizerWrapper::BertPreTokenizer(_) => {
                    Py::new(py, (PyBertPreTokenizer {}, base)).map(Into::into)
                }
            },
        }
    }
}

impl PreTokenizer for PyPreTokenizer {
    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
        self.pretok.pre_tokenize(normalized)
    }
}
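// Pickle support: `__getstate__`/`__setstate__` below round-trip the wrapper
// through its serde JSON representation, the same format the Rust crate uses,
// e.g. roughly `{"type": "Whitespace"}` for a plain whitespace pre-tokenizer
// (the exact fields depend on the variant).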
#[pymethods]
impl PyPreTokenizer {
    // #[staticmethod]
    // fn custom(pretok: PyObject) -> PyResult<Self> {
    //     let py_pretok = CustomPreTokenizer::new(pretok)?;
    //     Ok(PyPreTokenizer {
    //         pretok: Arc::new(py_pretok),
    //     })
    // }

    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.pretok).map_err(|e| {
            exceptions::Exception::py_err(format!(
                "Error while attempting to pickle PreTokenizer: {}",
                e
            ))
        })?;
        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
    }

    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
        match state.extract::<&PyBytes>(py) {
            Ok(s) => {
                let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
                    exceptions::Exception::py_err(format!(
                        "Error while attempting to unpickle PreTokenizer: {}",
                        e
                    ))
                })?;
                self.pretok = unpickled;
                Ok(())
            }
            Err(e) => Err(e),
        }
    }

    fn pre_tokenize(&self, s: &str) -> PyResult<Vec<(String, Offsets)>> {
        // TODO: Expose the PreTokenizedString
        let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);

        ToPyResult(self.pretok.pre_tokenize(&mut pretokenized)).into_py()?;

        Ok(pretokenized
            .get_splits(tk::OffsetReferential::Original)
            .into_iter()
            .map(|(s, o, _)| (s.to_owned(), o))
            .collect())
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
pub struct PyByteLevel {}

#[pymethods]
impl PyByteLevel {
    #[new]
    #[args(kwargs = "**")]
    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
        let mut byte_level = ByteLevel::default();

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "add_prefix_space" => {
                        byte_level = byte_level.add_prefix_space(value.extract()?)
                    }
                    _ => println!("Ignored unknown kwargs option {}", key),
                }
            }
        }

        Ok((PyByteLevel {}, byte_level.into()))
    }

    #[staticmethod]
    fn alphabet() -> Vec<String> {
        ByteLevel::alphabet()
            .into_iter()
            .map(|c| c.to_string())
            .collect()
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Whitespace)]
pub struct PyWhitespace {}

#[pymethods]
impl PyWhitespace {
    #[new]
    fn new() -> PyResult<(Self, PyPreTokenizer)> {
        Ok((PyWhitespace {}, Whitespace::default().into()))
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=WhitespaceSplit)]
pub struct PyWhitespaceSplit {}

#[pymethods]
impl PyWhitespaceSplit {
    #[new]
    fn new() -> PyResult<(Self, PyPreTokenizer)> {
        Ok((PyWhitespaceSplit {}, WhitespaceSplit.into()))
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=CharDelimiterSplit)]
pub struct PyCharDelimiterSplit {}

#[pymethods]
impl PyCharDelimiterSplit {
    #[new]
    pub fn new(delimiter: &str) -> PyResult<(Self, PyPreTokenizer)> {
        let chr_delimiter = delimiter.chars().next().ok_or_else(|| {
            exceptions::Exception::py_err("delimiter must be a single character")
        })?;
        Ok((
            PyCharDelimiterSplit {},
            CharDelimiterSplit::new(chr_delimiter).into(),
        ))
    }

    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
        Ok(PyTuple::new(py, &[" "]))
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=BertPreTokenizer)]
pub struct PyBertPreTokenizer {}

#[pymethods]
impl PyBertPreTokenizer {
    #[new]
    fn new() -> PyResult<(Self, PyPreTokenizer)> {
        Ok((PyBertPreTokenizer {}, BertPreTokenizer.into()))
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Deduplication)]
pub struct PyDeduplication {}

#[pymethods]
impl PyDeduplication {
    #[new]
    fn new() -> PyResult<(Self, PyPreTokenizer)> {
        Ok((PyDeduplication {}, Deduplication.into()))
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Punctuation)]
pub struct PyPunctuation {}

#[pymethods]
impl PyPunctuation {
    #[new]
    fn new() -> PyResult<(Self, PyPreTokenizer)> {
        Ok((PyPunctuation {}, Punctuation.into()))
    }
}
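// `Sequence` composes several pre-tokenizers and runs them in order. Note
// that its constructor flattens nested sequences: a `Sequence` passed inside
// the list has its children spliced into the new sequence rather than being
// nested as a wrapper.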
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Sequence)]
pub struct PySequence {}

#[pymethods]
impl PySequence {
    #[new]
    fn new(pre_tokenizers: &PyList) -> PyResult<(Self, PyPreTokenizer)> {
        let mut sequence = Vec::with_capacity(pre_tokenizers.len());
        for n in pre_tokenizers.iter() {
            let pretokenizer: PyRef<PyPreTokenizer> = n.extract()?;
            match &pretokenizer.pretok {
                PyPreTokenizerWrapper::Sequence(inner) => sequence.extend(inner.iter().cloned()),
                PyPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner.clone()),
                PyPreTokenizerWrapper::Custom(_) => unreachable!(
                    "Custom pretokenizers are currently disabled, how did you get here?"
                ),
            }
        }
        Ok((
            PySequence {},
            PyPreTokenizer::new(PyPreTokenizerWrapper::Sequence(sequence)),
        ))
    }

    fn __getnewargs__<'p>(&self, py: Python<'p>) -> PyResult<&'p PyTuple> {
        Ok(PyTuple::new(py, &[PyList::empty(py)]))
    }
}

#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Metaspace)]
pub struct PyMetaspace {}

#[pymethods]
impl PyMetaspace {
    #[new]
    #[args(kwargs = "**")]
    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
        let mut replacement = '▁';
        let mut add_prefix_space = true;

        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "replacement" => {
                        let s: &str = value.extract()?;
                        replacement = s.chars().next().ok_or_else(|| {
                            exceptions::Exception::py_err("replacement must be a character")
                        })?;
                    }
                    "add_prefix_space" => add_prefix_space = value.extract()?,
                    _ => println!("Ignored unknown kwargs option {}", key),
                }
            }
        }

        Ok((
            PyMetaspace {},
            Metaspace::new(replacement, add_prefix_space).into(),
        ))
    }
}

// This is not accessible from Python since the `custom` method is disabled.
#[allow(dead_code)]
pub(crate) struct CustomPreTokenizer {
    class: PyObject,
}

impl CustomPreTokenizer {
    #[allow(dead_code)]
    pub fn new(class: PyObject) -> PyResult<Self> {
        Ok(CustomPreTokenizer { class })
    }
}

// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
//     fn pre_tokenize(&self, sentence: &mut PreTokenizedString) -> tk::Result<()> {
//         let gil = Python::acquire_gil();
//         let py = gil.python();
//
//         let args = PyTuple::new(py, &[sentence.get()]);
//         match self.class.call_method(py, "pre_tokenize", args, None) {
//             Ok(res) => Ok(res
//                 .cast_as::<PyList>(py)
//                 .map_err(|_| {
//                     PyError::from(
//                         "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
//                     )
//                 })?
//                 .extract::<Vec<(String, Offsets)>>()
//                 .map_err(|_| {
//                     PyError::from(
//                         "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
//                     )
//                 })?),
//             Err(e) => {
//                 e.print(py);
//                 Err(Box::new(PyError::from(
//                     "Error while calling `pre_tokenize`",
//                 )))
//             }
//         }
//     }
// }

impl Serialize for CustomPreTokenizer {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        Err(serde::ser::Error::custom(
            "Custom PyPreTokenizer cannot be serialized",
        ))
    }
}

impl<'de> Deserialize<'de> for CustomPreTokenizer {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        Err(serde::de::Error::custom(
            "Custom PyPreTokenizer cannot be deserialized",
        ))
    }
}

#[derive(Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum PyPreTokenizerWrapper {
    Sequence(Vec<Arc<PreTokenizerWrapper>>),
    Custom(Arc<CustomPreTokenizer>),
    Wrapped(Arc<PreTokenizerWrapper>),
}

impl Serialize for PyPreTokenizerWrapper {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            PyPreTokenizerWrapper::Sequence(seq) => {
                let mut ser = serializer.serialize_struct("Sequence", 2)?;
                ser.serialize_field("type", "Sequence")?;
                ser.serialize_field("pretokenizers", seq)?;
                ser.end()
            }
            PyPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
            PyPreTokenizerWrapper::Custom(inner) => inner.serialize(serializer),
        }
    }
}

impl<I> From<I> for PyPreTokenizerWrapper
where
    I: Into<PreTokenizerWrapper>,
{
    fn from(pretok: I) -> Self {
        PyPreTokenizerWrapper::Wrapped(Arc::new(pretok.into()))
    }
}

impl<I> From<I> for PyPreTokenizer
where
    I: Into<PreTokenizerWrapper>,
{
    fn from(pretok: I) -> Self {
        PyPreTokenizer {
            pretok: pretok.into().into(),
        }
    }
}

impl PreTokenizer for PyPreTokenizerWrapper {
    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
        match self {
            PyPreTokenizerWrapper::Wrapped(inner) => inner.pre_tokenize(normalized),
            PyPreTokenizerWrapper::Sequence(inner) => {
                // Collecting an iterator of `tk::Result<()>` into a single
                // `tk::Result<()>` runs every pre-tokenizer in order and
                // stops at the first error.
                inner.iter().map(|n| n.pre_tokenize(normalized)).collect()
            }
            PyPreTokenizerWrapper::Custom(_) => {
                unreachable!("Custom pretokenizers are currently disabled, how did you get here?")
            }
        }
    }
}
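// A sketch of how the blanket `From` impls above are used by the constructors
// in this module: anything convertible into `PreTokenizerWrapper` converts
// straight into the Python-facing type, e.g.
// `let pretok: PyPreTokenizer = Whitespace::default().into();`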
#[cfg(test)]
mod test {
    use pyo3::{AsPyRef, Py, PyObject, Python};
    use tk::pre_tokenizers::whitespace::Whitespace;
    use tk::pre_tokenizers::PreTokenizerWrapper;

    use crate::pre_tokenizers::{CustomPreTokenizer, PyPreTokenizer, PyPreTokenizerWrapper};
    use std::sync::Arc;

    #[test]
    fn get_subtype() {
        let py_pretok = PyPreTokenizer::new(Whitespace::default().into());
        let py_wsp = py_pretok.get_as_subtype().unwrap();
        let gil = Python::acquire_gil();
        assert_eq!(
            "tokenizers.pre_tokenizers.Whitespace",
            py_wsp.as_ref(gil.python()).get_type().name()
        );
    }

    #[test]
    fn serialize() {
        let py_wrapped: PyPreTokenizerWrapper = Whitespace::default().into();
        let py_ser = serde_json::to_string(&py_wrapped).unwrap();
        let rs_wrapped = PreTokenizerWrapper::Whitespace(Whitespace::default());
        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
        assert_eq!(py_ser, rs_ser);

        let py_pretok: PyPreTokenizer = serde_json::from_str(&rs_ser).unwrap();
        match py_pretok.pretok {
            PyPreTokenizerWrapper::Wrapped(wsp) => match wsp.as_ref() {
                PreTokenizerWrapper::Whitespace(_) => {}
                _ => panic!("Expected Whitespace"),
            },
            _ => panic!("Expected wrapped, not custom."),
        }

        let gil = Python::acquire_gil();
        let py_wsp = PyPreTokenizer::new(Whitespace::default().into());
        let obj: PyObject = Py::new(gil.python(), py_wsp).unwrap().into();
        let py_custom: PyPreTokenizerWrapper =
            PyPreTokenizerWrapper::Custom(Arc::new(CustomPreTokenizer::new(obj).unwrap()));
        assert!(serde_json::to_string(&py_custom).is_err());
    }
}
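// A minimal extra check sketching the contract of the hand-written
// `Serialize` impl above: `PyPreTokenizerWrapper::Sequence` is expected to
// emit an explicit `"type": "Sequence"` tag followed by the flattened
// `pretokenizers` list (serde_json preserves `serialize_struct` field order).
#[cfg(test)]
mod sequence_serde_test {
    use std::sync::Arc;

    use tk::pre_tokenizers::whitespace::Whitespace;
    use tk::pre_tokenizers::PreTokenizerWrapper;

    use crate::pre_tokenizers::PyPreTokenizerWrapper;

    #[test]
    fn sequence_serde() {
        // A one-element sequence wrapping the plain Whitespace pre-tokenizer.
        let seq = PyPreTokenizerWrapper::Sequence(vec![Arc::new(
            PreTokenizerWrapper::Whitespace(Whitespace::default()),
        )]);
        let ser = serde_json::to_string(&seq).unwrap();
        // The custom impl writes the tag first, then the children.
        assert!(ser.starts_with("{\"type\":\"Sequence\",\"pretokenizers\":["));
    }
}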