Remove Container from PreTokenizers, replace with Arc.

* prefix the Python types in Rust with Py, rename PyPreTokenizer
  to CustomPreTokenizer
* remove unsound Container wrappers, replace with Arc
* change the panic when trying to (de-)serialize a custom
  pretokenizer into an exception
Sebastian Pütz authored on 2020-07-25 18:22:49 +02:00, committed by Anthony MOI
parent bcc54a2ea1, commit b411443128
3 changed files with 127 additions and 129 deletions
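Before the per-file diffs, a minimal, self-contained sketch of the ownership pattern the change adopts. The trait and structs here are simplified stand-ins rather than the real tk types; only the Arc-based sharing mirrors the code below.

use std::sync::Arc;

// Illustrative stand-in for the real `tk::PreTokenizer` trait.
trait PreTokenizer: Send + Sync {
    fn pre_tokenize(&self, input: &str) -> Vec<String>;
}

struct Whitespace;

impl PreTokenizer for Whitespace {
    fn pre_tokenize(&self, input: &str) -> Vec<String> {
        input.split_whitespace().map(str::to_owned).collect()
    }
}

// The Python-facing wrapper holds an `Arc<dyn PreTokenizer>` instead of the old
// `Container`, so it can derive `Clone`: handing it to a tokenizer while keeping
// it on the Python side just bumps a reference count instead of moving ownership.
#[derive(Clone)]
struct PyPreTokenizer {
    pretok: Arc<dyn PreTokenizer>,
}

fn main() {
    let wrapper = PyPreTokenizer { pretok: Arc::new(Whitespace) };
    let shared = wrapper.clone(); // both copies point at the same pre-tokenizer
    assert_eq!(shared.pretok.pre_tokenize("hugging face"), vec!["hugging", "face"]);
}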


@@ -58,13 +58,13 @@ fn models(_py: Python, m: &PyModule) -> PyResult<()> {
/// PreTokenizers Module
#[pymodule]
fn pre_tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<pre_tokenizers::PreTokenizer>()?;
m.add_class::<pre_tokenizers::ByteLevel>()?;
m.add_class::<pre_tokenizers::Whitespace>()?;
m.add_class::<pre_tokenizers::WhitespaceSplit>()?;
m.add_class::<pre_tokenizers::BertPreTokenizer>()?;
m.add_class::<pre_tokenizers::Metaspace>()?;
m.add_class::<pre_tokenizers::CharDelimiterSplit>()?;
m.add_class::<pre_tokenizers::PyPreTokenizer>()?;
m.add_class::<pre_tokenizers::PyByteLevel>()?;
m.add_class::<pre_tokenizers::PyWhitespace>()?;
m.add_class::<pre_tokenizers::PyWhitespaceSplit>()?;
m.add_class::<pre_tokenizers::PyBertPreTokenizer>()?;
m.add_class::<pre_tokenizers::PyMetaspace>()?;
m.add_class::<pre_tokenizers::PyCharDelimiterSplit>()?;
Ok(())
}


@@ -1,31 +1,70 @@
extern crate tokenizers as tk;
use std::sync::Arc;
use super::error::ToPyResult;
use super::utils::Container;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use tk::tokenizer::Offsets;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[pyclass(dict, module = "tokenizers.pre_tokenizers")]
pub struct PreTokenizer {
pub pretok: Container<dyn tk::tokenizer::PreTokenizer>,
use tk::pre_tokenizers::bert::BertPreTokenizer;
use tk::pre_tokenizers::byte_level::ByteLevel;
use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
use tk::pre_tokenizers::metaspace::Metaspace;
use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use tk::tokenizer::Offsets;
use tk::{PreTokenizedString, PreTokenizer};
use tokenizers as tk;
use super::error::ToPyResult;
#[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
#[derive(Clone)]
pub struct PyPreTokenizer {
pub pretok: Arc<dyn PreTokenizer>,
}
impl PyPreTokenizer {
pub fn new(pretok: Arc<dyn PreTokenizer>) -> Self {
PyPreTokenizer { pretok }
}
}
impl Serialize for PyPreTokenizer {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.pretok.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for PyPreTokenizer {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Arc::deserialize(deserializer).map(PyPreTokenizer::new)
}
}
#[typetag::serde]
impl PreTokenizer for PyPreTokenizer {
fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
self.pretok.pre_tokenize(normalized)
}
}
#[pymethods]
impl PreTokenizer {
impl PyPreTokenizer {
// #[staticmethod]
// fn custom(pretok: PyObject) -> PyResult<Self> {
// let py_pretok = PyPreTokenizer::new(pretok)?;
// Ok(PreTokenizer {
// pretok: Container::Owned(Box::new(py_pretok)),
// let py_pretok = CustomPreTokenizer::new(pretok)?;
// Ok(PyPreTokenizer {
// pretok: Arc::new(py_pretok),
// })
// }
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = self
.pretok
.execute(|pretok| serde_json::to_string(&pretok))
.map_err(|e| {
let data = serde_json::to_string(&self.pretok.as_ref()).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to pickle PreTokenizer: {}",
e.to_string()
@@ -37,13 +76,13 @@ impl PreTokenizer {
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.pretok =
Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to unpickle PreTokenizer: {}",
e.to_string()
))
})?);
})?;
self.pretok = unpickled;
Ok(())
}
Err(e) => Err(e),
@@ -54,11 +93,7 @@ impl PreTokenizer {
// TODO: Expose the PreTokenizedString
let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
ToPyResult(
self.pretok
.execute(|pretok| pretok.pre_tokenize(&mut pretokenized)),
)
.into_py()?;
ToPyResult(self.pretok.pre_tokenize(&mut pretokenized)).into_py()?;
Ok(pretokenized
.get_normalized(tk::OffsetReferential::Original)
@@ -68,14 +103,14 @@ impl PreTokenizer {
}
}
#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
pub struct ByteLevel {}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
pub struct PyByteLevel {}
#[pymethods]
impl ByteLevel {
impl PyByteLevel {
#[new]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> {
let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default();
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
let mut byte_level = ByteLevel::default();
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
@@ -88,61 +123,50 @@ impl ByteLevel {
}
}
Ok((
ByteLevel {},
PreTokenizer {
pretok: Container::Owned(Box::new(byte_level)),
},
))
Ok((PyByteLevel {}, PyPreTokenizer::new(Arc::new(byte_level))))
}
#[staticmethod]
fn alphabet() -> Vec<String> {
tk::pre_tokenizers::byte_level::ByteLevel::alphabet()
ByteLevel::alphabet()
.into_iter()
.map(|c| c.to_string())
.collect()
}
}
#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
pub struct Whitespace {}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Whitespace)]
pub struct PyWhitespace {}
#[pymethods]
impl Whitespace {
impl PyWhitespace {
#[new]
fn new() -> PyResult<(Self, PreTokenizer)> {
fn new() -> PyResult<(Self, PyPreTokenizer)> {
Ok((
Whitespace {},
PreTokenizer {
pretok: Container::Owned(Box::new(
tk::pre_tokenizers::whitespace::Whitespace::default(),
)),
},
PyWhitespace {},
PyPreTokenizer::new(Arc::new(Whitespace::default())),
))
}
}
#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
pub struct WhitespaceSplit {}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=WhitespaceSplit)]
pub struct PyWhitespaceSplit {}
#[pymethods]
impl WhitespaceSplit {
impl PyWhitespaceSplit {
#[new]
fn new() -> PyResult<(Self, PreTokenizer)> {
fn new() -> PyResult<(Self, PyPreTokenizer)> {
Ok((
WhitespaceSplit {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::WhitespaceSplit)),
},
PyWhitespaceSplit {},
PyPreTokenizer::new(Arc::new(WhitespaceSplit)),
))
}
}
#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
pub struct CharDelimiterSplit {}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=CharDelimiterSplit)]
pub struct PyCharDelimiterSplit {}
#[pymethods]
impl CharDelimiterSplit {
impl PyCharDelimiterSplit {
#[new]
pub fn new(delimiter: &str) -> PyResult<(Self, PreTokenizer)> {
pub fn new(delimiter: &str) -> PyResult<(Self, PyPreTokenizer)> {
let chr_delimiter = delimiter
.chars()
.nth(0)
@@ -150,12 +174,8 @@ impl CharDelimiterSplit {
"delimiter must be a single character",
))?;
Ok((
CharDelimiterSplit {},
PreTokenizer {
pretok: Container::Owned(Box::new(
tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(chr_delimiter),
)),
},
PyCharDelimiterSplit {},
PyPreTokenizer::new(Arc::new(CharDelimiterSplit::new(chr_delimiter))),
))
}
@@ -164,28 +184,26 @@ impl CharDelimiterSplit {
}
}
#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
pub struct BertPreTokenizer {}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=BertPreTokenizer)]
pub struct PyBertPreTokenizer {}
#[pymethods]
impl BertPreTokenizer {
impl PyBertPreTokenizer {
#[new]
fn new() -> PyResult<(Self, PreTokenizer)> {
fn new() -> PyResult<(Self, PyPreTokenizer)> {
Ok((
BertPreTokenizer {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)),
},
PyBertPreTokenizer {},
PyPreTokenizer::new(Arc::new(BertPreTokenizer)),
))
}
}
#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
pub struct Metaspace {}
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Metaspace)]
pub struct PyMetaspace {}
#[pymethods]
impl Metaspace {
impl PyMetaspace {
#[new]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
let mut replacement = '▁';
let mut add_prefix_space = true;
@@ -206,33 +224,25 @@ impl Metaspace {
}
Ok((
Metaspace {},
PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::metaspace::Metaspace::new(
replacement,
add_prefix_space,
))),
},
PyMetaspace {},
PyPreTokenizer::new(Arc::new(Metaspace::new(replacement, add_prefix_space))),
))
}
}
// struct PyPreTokenizer {
// struct CustomPreTokenizer {
// class: PyObject,
// }
//
// impl PyPreTokenizer {
// impl CustomPreTokenizer {
// pub fn new(class: PyObject) -> PyResult<Self> {
// Ok(PyPreTokenizer { class })
// Ok(CustomPreTokenizer { class })
// }
// }
//
// #[typetag::serde]
// impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
// fn pre_tokenize(
// &self,
// sentence: &mut tk::tokenizer::NormalizedString,
// ) -> Result<Vec<(String, Offsets)>> {
// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
// fn pre_tokenize(&self, sentence: &mut NormalizedString) -> tk::Result<Vec<(String, Offsets)>> {
// let gil = Python::acquire_gil();
// let py = gil.python();
//
@@ -259,8 +269,8 @@ impl Metaspace {
// }
// }
//
// impl Serialize for PyPreTokenizer {
// fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
// impl Serialize for CustomPreTokenizer {
// fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
// where
// S: Serializer,
// {
@@ -270,11 +280,11 @@ impl Metaspace {
// }
// }
//
// impl<'de> Deserialize<'de> for PyPreTokenizer {
// fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
// impl<'de> Deserialize<'de> for CustomPreTokenizer {
// fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
// where
// D: Deserializer<'de>,
// {
// unimplemented!("PyPreTokenizer cannot be deserialized")
// Err(D::Error::custom("PyDecoder cannot be deserialized"))
// }
// }
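The commented-out CustomPreTokenizer above now reports a serde error instead of calling unimplemented!(), which is what lets the bindings raise a Python exception rather than panic. A minimal sketch of that pattern, assuming the serde and serde_json crates; the type name here is only a placeholder.

use serde::de::{Deserialize, Deserializer, Error};

#[derive(Debug)]
struct CustomPreTokenizer;

impl<'de> Deserialize<'de> for CustomPreTokenizer {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Surface a serde error; the caller can map it to a Python exception.
        Err(D::Error::custom("CustomPreTokenizer cannot be deserialized"))
    }
}

fn main() {
    let err = serde_json::from_str::<CustomPreTokenizer>("{}").unwrap_err();
    println!("{}", err); // error message instead of a process-aborting panic
}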


@@ -17,7 +17,7 @@ use super::encoding::PyEncoding;
use super::error::{PyError, ToPyResult};
use super::models::PyModel;
use super::normalizers::PyNormalizer;
use super::pre_tokenizers::PreTokenizer;
use super::pre_tokenizers::PyPreTokenizer;
use super::processors::PostProcessor;
use super::trainers::PyTrainer;
use super::utils::Container;
@@ -268,7 +268,7 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
}
}
type TokenizerImpl = Tokenizer<PyModel, PyNormalizer>;
type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer>;
#[pyclass(dict, module = "tokenizers")]
pub struct PyTokenizer {
@@ -717,25 +717,13 @@ impl PyTokenizer {
}
#[getter]
fn get_pre_tokenizer(&self) -> PyResult<Option<PreTokenizer>> {
Ok(self
.tokenizer
.get_pre_tokenizer()
.map(|pretok| PreTokenizer {
pretok: Container::from_ref(pretok),
}))
fn get_pre_tokenizer(&self) -> Option<PyPreTokenizer> {
self.tokenizer.get_pre_tokenizer().cloned()
}
#[setter]
fn set_pre_tokenizer(&mut self, mut pretok: PyRefMut<PreTokenizer>) -> PyResult<()> {
if let Some(pretok) = pretok.pretok.to_pointer() {
self.tokenizer.with_pre_tokenizer(pretok);
Ok(())
} else {
Err(exceptions::Exception::py_err(
"The PreTokenizer is already being used in another Tokenizer",
))
}
fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
self.tokenizer.with_pre_tokenizer(pretok.clone());
}
#[getter]