Remove Container from PreTokenizers, replace with Arc.

* prefix the Python types in Rust with Py, rename PyPreTokenizer
  to CustomPreTokenizer
* remove unsound Container wrappers, replace with Arc (see the sketch
  after this message)
* raise an exception instead of panicking when trying to (de-)serialize
  a custom pre-tokenizer
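
For orientation, here is a minimal, self-contained sketch of the ownership pattern the diff below moves to. It is plain Rust with no pyo3/serde/typetag; the trait signature and the toy Whitespace type are simplified stand-ins, not the library's real API. The point it illustrates is that the Python-facing wrapper now holds an Arc<dyn PreTokenizer> instead of the old Container, so a concrete pre-tokenizer is built once and shared through cheap reference-counted clones, and failures surface as Results rather than panics.

use std::sync::Arc;

// Simplified stand-in for the library's `PreTokenizer` trait.
trait PreTokenizer: Send + Sync {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>, String>;
}

// A concrete pre-tokenizer, analogous to `Whitespace` in the diff.
struct Whitespace;

impl PreTokenizer for Whitespace {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>, String> {
        Ok(s.split_whitespace().map(str::to_owned).collect())
    }
}

// Python-facing wrapper: owns the pre-tokenizer behind an `Arc` instead of
// the old `Container`, so it can be cloned and shared between the tokenizer
// and the Python object without unsound wrappers.
#[derive(Clone)]
struct PyPreTokenizer {
    pretok: Arc<dyn PreTokenizer>,
}

impl PyPreTokenizer {
    fn new(pretok: Arc<dyn PreTokenizer>) -> Self {
        PyPreTokenizer { pretok }
    }
}

// The wrapper just forwards the trait, mirroring the `#[typetag::serde]`
// impl added in the diff.
impl PreTokenizer for PyPreTokenizer {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>, String> {
        self.pretok.pre_tokenize(s)
    }
}

fn main() {
    // Mirrors `PyPreTokenizer::new(Arc::new(Whitespace::default()))` below.
    let wrapper = PyPreTokenizer::new(Arc::new(Whitespace));
    // Cloning only bumps the refcount; both handles share one pre-tokenizer.
    let shared = wrapper.clone();
    assert_eq!(
        shared.pre_tokenize("Hello there").unwrap(),
        vec!["Hello".to_string(), "there".to_string()]
    );
}

The custom (Python-defined) pre-tokenizer path, still commented out in the diff, follows the same shape; its Deserialize impl now returns a serde error instead of hitting unimplemented!(), which is what turns the old panic into a Python exception.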
Author: Sebastian Pütz
Date: 2020-07-25 18:22:49 +02:00
Committed by: Anthony MOI
Parent: bcc54a2ea1
Commit: b411443128

3 changed files with 127 additions and 129 deletions


@@ -58,13 +58,13 @@ fn models(_py: Python, m: &PyModule) -> PyResult<()> {
 /// PreTokenizers Module
 #[pymodule]
 fn pre_tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<pre_tokenizers::PreTokenizer>()?;
-    m.add_class::<pre_tokenizers::ByteLevel>()?;
-    m.add_class::<pre_tokenizers::Whitespace>()?;
-    m.add_class::<pre_tokenizers::WhitespaceSplit>()?;
-    m.add_class::<pre_tokenizers::BertPreTokenizer>()?;
-    m.add_class::<pre_tokenizers::Metaspace>()?;
-    m.add_class::<pre_tokenizers::CharDelimiterSplit>()?;
+    m.add_class::<pre_tokenizers::PyPreTokenizer>()?;
+    m.add_class::<pre_tokenizers::PyByteLevel>()?;
+    m.add_class::<pre_tokenizers::PyWhitespace>()?;
+    m.add_class::<pre_tokenizers::PyWhitespaceSplit>()?;
+    m.add_class::<pre_tokenizers::PyBertPreTokenizer>()?;
+    m.add_class::<pre_tokenizers::PyMetaspace>()?;
+    m.add_class::<pre_tokenizers::PyCharDelimiterSplit>()?;
     Ok(())
 }


@@ -1,31 +1,70 @@
-extern crate tokenizers as tk;
-use super::error::ToPyResult;
-use super::utils::Container;
+use std::sync::Arc;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
-use tk::tokenizer::Offsets;
-#[pyclass(dict, module = "tokenizers.pre_tokenizers")]
-pub struct PreTokenizer {
-    pub pretok: Container<dyn tk::tokenizer::PreTokenizer>,
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use tk::pre_tokenizers::bert::BertPreTokenizer;
+use tk::pre_tokenizers::byte_level::ByteLevel;
+use tk::pre_tokenizers::delimiter::CharDelimiterSplit;
+use tk::pre_tokenizers::metaspace::Metaspace;
+use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
+use tk::tokenizer::Offsets;
+use tk::{PreTokenizedString, PreTokenizer};
+use tokenizers as tk;
+use super::error::ToPyResult;
+#[pyclass(dict, module = "tokenizers.pre_tokenizers", name=PreTokenizer)]
+#[derive(Clone)]
+pub struct PyPreTokenizer {
+    pub pretok: Arc<dyn PreTokenizer>,
 }
+impl PyPreTokenizer {
+    pub fn new(pretok: Arc<dyn PreTokenizer>) -> Self {
+        PyPreTokenizer { pretok }
+    }
+}
+impl Serialize for PyPreTokenizer {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.pretok.serialize(serializer)
+    }
+}
+impl<'de> Deserialize<'de> for PyPreTokenizer {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Arc::deserialize(deserializer).map(PyPreTokenizer::new)
+    }
+}
+#[typetag::serde]
+impl PreTokenizer for PyPreTokenizer {
+    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> tk::Result<()> {
+        self.pretok.pre_tokenize(normalized)
+    }
+}
 #[pymethods]
-impl PreTokenizer {
+impl PyPreTokenizer {
     // #[staticmethod]
     // fn custom(pretok: PyObject) -> PyResult<Self> {
-    //     let py_pretok = PyPreTokenizer::new(pretok)?;
-    //     Ok(PreTokenizer {
-    //         pretok: Container::Owned(Box::new(py_pretok)),
+    //     let py_pretok = CustomPreTokenizer::new(pretok)?;
+    //     Ok(PyPreTokenizer {
+    //         pretok: Arc::new(py_pretok),
     //     })
     // }
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .pretok
-            .execute(|pretok| serde_json::to_string(&pretok))
-            .map_err(|e| {
+        let data = serde_json::to_string(&self.pretok.as_ref()).map_err(|e| {
             exceptions::Exception::py_err(format!(
                 "Error while attempting to pickle PreTokenizer: {}",
                 e.to_string()
@@ -37,13 +76,13 @@ impl PreTokenizer {
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
            Ok(s) => {
-                self.pretok =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                let unpickled = serde_json::from_slice(s.as_bytes()).map_err(|e| {
                     exceptions::Exception::py_err(format!(
                         "Error while attempting to unpickle PreTokenizer: {}",
                         e.to_string()
                     ))
-                    })?);
+                })?;
+                self.pretok = unpickled;
                 Ok(())
             }
             Err(e) => Err(e),
@@ -54,11 +93,7 @@ impl PreTokenizer {
         // TODO: Expose the PreTokenizedString
         let mut pretokenized = tk::tokenizer::PreTokenizedString::from(s);
-        ToPyResult(
-            self.pretok
-                .execute(|pretok| pretok.pre_tokenize(&mut pretokenized)),
-        )
-        .into_py()?;
+        ToPyResult(self.pretok.pre_tokenize(&mut pretokenized)).into_py()?;
         Ok(pretokenized
             .get_normalized(tk::OffsetReferential::Original)
@@ -68,14 +103,14 @@ impl PreTokenizer {
     }
 }
-#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
-pub struct ByteLevel {}
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
+pub struct PyByteLevel {}
 #[pymethods]
-impl ByteLevel {
+impl PyByteLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> {
-        let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default();
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
+        let mut byte_level = ByteLevel::default();
         if let Some(kwargs) = kwargs {
             for (key, value) in kwargs {
                 let key: &str = key.extract()?;
@@ -88,61 +123,50 @@ impl ByteLevel {
             }
         }
-        Ok((
-            ByteLevel {},
-            PreTokenizer {
-                pretok: Container::Owned(Box::new(byte_level)),
-            },
-        ))
+        Ok((PyByteLevel {}, PyPreTokenizer::new(Arc::new(byte_level))))
     }
     #[staticmethod]
     fn alphabet() -> Vec<String> {
-        tk::pre_tokenizers::byte_level::ByteLevel::alphabet()
+        ByteLevel::alphabet()
             .into_iter()
             .map(|c| c.to_string())
             .collect()
     }
 }
-#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
-pub struct Whitespace {}
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Whitespace)]
+pub struct PyWhitespace {}
 #[pymethods]
-impl Whitespace {
+impl PyWhitespace {
     #[new]
-    fn new() -> PyResult<(Self, PreTokenizer)> {
+    fn new() -> PyResult<(Self, PyPreTokenizer)> {
         Ok((
-            Whitespace {},
-            PreTokenizer {
-                pretok: Container::Owned(Box::new(
-                    tk::pre_tokenizers::whitespace::Whitespace::default(),
-                )),
-            },
+            PyWhitespace {},
+            PyPreTokenizer::new(Arc::new(Whitespace::default())),
         ))
     }
 }
-#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
-pub struct WhitespaceSplit {}
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=WhitespaceSplit)]
+pub struct PyWhitespaceSplit {}
 #[pymethods]
-impl WhitespaceSplit {
+impl PyWhitespaceSplit {
     #[new]
-    fn new() -> PyResult<(Self, PreTokenizer)> {
+    fn new() -> PyResult<(Self, PyPreTokenizer)> {
         Ok((
-            WhitespaceSplit {},
-            PreTokenizer {
-                pretok: Container::Owned(Box::new(tk::pre_tokenizers::whitespace::WhitespaceSplit)),
-            },
+            PyWhitespaceSplit {},
+            PyPreTokenizer::new(Arc::new(WhitespaceSplit)),
         ))
     }
 }
-#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
-pub struct CharDelimiterSplit {}
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=CharDelimiterSplit)]
+pub struct PyCharDelimiterSplit {}
 #[pymethods]
-impl CharDelimiterSplit {
+impl PyCharDelimiterSplit {
     #[new]
-    pub fn new(delimiter: &str) -> PyResult<(Self, PreTokenizer)> {
+    pub fn new(delimiter: &str) -> PyResult<(Self, PyPreTokenizer)> {
         let chr_delimiter = delimiter
             .chars()
             .nth(0)
@@ -150,12 +174,8 @@ impl CharDelimiterSplit {
             "delimiter must be a single character",
         ))?;
         Ok((
-            CharDelimiterSplit {},
-            PreTokenizer {
-                pretok: Container::Owned(Box::new(
-                    tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(chr_delimiter),
-                )),
-            },
+            PyCharDelimiterSplit {},
+            PyPreTokenizer::new(Arc::new(CharDelimiterSplit::new(chr_delimiter))),
         ))
     }
@@ -164,28 +184,26 @@ impl CharDelimiterSplit {
     }
 }
-#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
-pub struct BertPreTokenizer {}
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=BertPreTokenizer)]
+pub struct PyBertPreTokenizer {}
 #[pymethods]
-impl BertPreTokenizer {
+impl PyBertPreTokenizer {
     #[new]
-    fn new() -> PyResult<(Self, PreTokenizer)> {
+    fn new() -> PyResult<(Self, PyPreTokenizer)> {
         Ok((
-            BertPreTokenizer {},
-            PreTokenizer {
-                pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)),
-            },
+            PyBertPreTokenizer {},
+            PyPreTokenizer::new(Arc::new(BertPreTokenizer)),
         ))
     }
 }
-#[pyclass(extends=PreTokenizer, module = "tokenizers.pre_tokenizers")]
-pub struct Metaspace {}
+#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=Metaspace)]
+pub struct PyMetaspace {}
 #[pymethods]
-impl Metaspace {
+impl PyMetaspace {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PreTokenizer)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyPreTokenizer)> {
         let mut replacement = '▁';
         let mut add_prefix_space = true;
@@ -206,33 +224,25 @@ impl Metaspace {
         }
         Ok((
-            Metaspace {},
-            PreTokenizer {
-                pretok: Container::Owned(Box::new(tk::pre_tokenizers::metaspace::Metaspace::new(
-                    replacement,
-                    add_prefix_space,
-                ))),
-            },
+            PyMetaspace {},
+            PyPreTokenizer::new(Arc::new(Metaspace::new(replacement, add_prefix_space))),
         ))
     }
 }
-// struct PyPreTokenizer {
+// struct CustomPreTokenizer {
 //     class: PyObject,
 // }
 //
-// impl PyPreTokenizer {
+// impl CustomPreTokenizer {
 //     pub fn new(class: PyObject) -> PyResult<Self> {
-//         Ok(PyPreTokenizer { class })
+//         Ok(CustomPreTokenizer { class })
 //     }
 // }
 //
 // #[typetag::serde]
-// impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
-//     fn pre_tokenize(
-//         &self,
-//         sentence: &mut tk::tokenizer::NormalizedString,
-//     ) -> Result<Vec<(String, Offsets)>> {
+// impl tk::tokenizer::PreTokenizer for CustomPreTokenizer {
+//     fn pre_tokenize(&self, sentence: &mut NormalizedString) -> tk::Result<Vec<(String, Offsets)>> {
 //         let gil = Python::acquire_gil();
 //         let py = gil.python();
 //
@@ -259,8 +269,8 @@ impl Metaspace {
 //     }
 // }
 //
-// impl Serialize for PyPreTokenizer {
-//     fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
+// impl Serialize for CustomPreTokenizer {
+//     fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
 //     where
 //         S: Serializer,
 //     {
@@ -270,11 +280,11 @@ impl Metaspace {
 //     }
 // }
 //
-// impl<'de> Deserialize<'de> for PyPreTokenizer {
-//     fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
+// impl<'de> Deserialize<'de> for CustomPreTokenizer {
+//     fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
 //     where
 //         D: Deserializer<'de>,
 //     {
-//         unimplemented!("PyPreTokenizer cannot be deserialized")
+//         Err(D::Error::custom("PyDecoder cannot be deserialized"))
 //     }
 // }


@@ -17,7 +17,7 @@ use super::encoding::PyEncoding;
 use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
 use super::normalizers::PyNormalizer;
-use super::pre_tokenizers::PreTokenizer;
+use super::pre_tokenizers::PyPreTokenizer;
 use super::processors::PostProcessor;
 use super::trainers::PyTrainer;
 use super::utils::Container;
@@ -268,7 +268,7 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }
-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer>;
 #[pyclass(dict, module = "tokenizers")]
 pub struct PyTokenizer {
@@ -717,25 +717,13 @@ impl PyTokenizer {
     }
     #[getter]
-    fn get_pre_tokenizer(&self) -> PyResult<Option<PreTokenizer>> {
-        Ok(self
-            .tokenizer
-            .get_pre_tokenizer()
-            .map(|pretok| PreTokenizer {
-                pretok: Container::from_ref(pretok),
-            }))
+    fn get_pre_tokenizer(&self) -> Option<PyPreTokenizer> {
+        self.tokenizer.get_pre_tokenizer().cloned()
     }
     #[setter]
-    fn set_pre_tokenizer(&mut self, mut pretok: PyRefMut<PreTokenizer>) -> PyResult<()> {
-        if let Some(pretok) = pretok.pretok.to_pointer() {
-            self.tokenizer.with_pre_tokenizer(pretok);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The PreTokenizer is already being used in another Tokenizer",
-            ))
-        }
+    fn set_pre_tokenizer(&mut self, pretok: PyRef<PyPreTokenizer>) {
+        self.tokenizer.with_pre_tokenizer(pretok.clone());
     }
     #[getter]