Remove Container, rework PyDecoder, make Tokenizer cloneable.

* Derive Clone on Tokenizer and AddedVocabulary.
* Replace Container with Arc wrapper for Decoders.
* Prefix Rust Decoder types with Py.
* Rename PyDecoder to CustomDecoder.
* Raise an exception instead of panicking when serializing a custom decoder.
* Re-enable training with cloneable Tokenizer.
* Remove unsound Container, use Arc wrappers instead.
Sebastian Pütz
2020-07-25 20:07:24 +02:00
committed by Anthony MOI
parent ece6ad9149
commit d62adf7195
6 changed files with 117 additions and 169 deletions
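
The heart of the change, reduced to a standalone sketch: the Python-facing wrapper previously held a Container that juggled Box ownership against raw pointers; it now holds an Arc to an immutable trait object, so Clone is derivable and a clone is a cheap reference-count bump. Decode, Joiner, and DecoderHandle below are illustrative stand-ins, not types from the crate:

use std::sync::Arc;

// Stand-in for tk::Decoder; the real trait returns a Result.
trait Decode: Send + Sync {
    fn decode(&self, tokens: Vec<String>) -> String;
}

struct Joiner;
impl Decode for Joiner {
    fn decode(&self, tokens: Vec<String>) -> String {
        tokens.join(" ")
    }
}

// Shared ownership of an immutable trait object: #[derive(Clone)]
// just works, and no unsafe pointer juggling is needed to hand the
// same decoder to both Python and Rust.
#[derive(Clone)]
struct DecoderHandle {
    decoder: Arc<dyn Decode>,
}

fn main() {
    let a = DecoderHandle { decoder: Arc::new(Joiner) };
    let b = a.clone(); // both handles share one decoder
    assert_eq!(b.decoder.decode(vec!["hey".into(), "you".into()]), "hey you");
    assert_eq!(Arc::strong_count(&a.decoder), 2);
}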


@@ -1,50 +1,85 @@
extern crate tokenizers as tk;
use std::sync::Arc;
use super::error::{PyError, ToPyResult};
use super::utils::Container;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::tokenizer::Result;
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::wordpiece::WordPiece;
use tk::Decoder;
use tokenizers as tk;
#[pyclass(dict, module = "tokenizers.decoders")]
pub struct Decoder {
pub decoder: Container<dyn tk::tokenizer::Decoder>,
use super::error::{PyError, ToPyResult};
#[pyclass(dict, module = "tokenizers.decoders", name=Decoder)]
#[derive(Clone)]
pub struct PyDecoder {
pub decoder: Arc<dyn Decoder>,
}
impl PyDecoder {
pub fn new(decoder: Arc<dyn Decoder>) -> Self {
PyDecoder { decoder }
}
}
#[typetag::serde]
impl Decoder for PyDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
self.decoder.decode(tokens)
}
}
impl Serialize for PyDecoder {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.decoder.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for PyDecoder {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Ok(PyDecoder {
decoder: Arc::deserialize(deserializer)?,
})
}
}
#[pymethods]
impl Decoder {
impl PyDecoder {
#[staticmethod]
fn custom(decoder: PyObject) -> PyResult<Self> {
let decoder = PyDecoder::new(decoder)?;
Ok(Decoder {
decoder: Container::Owned(Box::new(decoder)),
})
let decoder = CustomDecoder::new(decoder).map(Arc::new)?;
Ok(PyDecoder::new(decoder))
}
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = self
.decoder
.execute(|decoder| serde_json::to_string(&decoder))
.map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to pickle Decoder: {}",
e.to_string()
))
})?;
let data = serde_json::to_string(&self.decoder).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to pickle Decoder: {}",
e
))
})?;
Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
}
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.decoder =
Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to unpickle Decoder: {}",
e.to_string()
))
})?);
self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
exceptions::Exception::py_err(format!(
"Error while attempting to unpickle Decoder: {}",
e
))
})?;
Ok(())
}
Err(e) => Err(e),
@@ -52,32 +87,30 @@ impl Decoder {
}
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
ToPyResult(self.decoder.execute(|decoder| decoder.decode(tokens))).into()
ToPyResult(self.decoder.decode(tokens)).into()
}
}
#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
pub struct ByteLevel {}
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=ByteLevel)]
pub struct PyByteLevelDec {}
#[pymethods]
impl ByteLevel {
impl PyByteLevelDec {
#[new]
fn new() -> PyResult<(Self, Decoder)> {
fn new() -> PyResult<(Self, PyDecoder)> {
Ok((
ByteLevel {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::default())),
},
PyByteLevelDec {},
PyDecoder::new(Arc::new(ByteLevel::default())),
))
}
}
#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
pub struct WordPiece {}
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=WordPiece)]
pub struct PyWordPieceDec {}
#[pymethods]
impl WordPiece {
impl PyWordPieceDec {
#[new]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyDecoder)> {
let mut prefix = String::from("##");
let mut cleanup = true;
@@ -91,23 +124,19 @@ impl WordPiece {
}
Ok((
WordPiece {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(
prefix, cleanup,
))),
},
PyWordPieceDec {},
PyDecoder::new(Arc::new(WordPiece::new(prefix, cleanup))),
))
}
}
#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
pub struct Metaspace {}
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=Metaspace)]
pub struct PyMetaspaceDec {}
#[pymethods]
impl Metaspace {
impl PyMetaspaceDec {
#[new]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyDecoder)> {
let mut replacement = '▁';
let mut add_prefix_space = true;
@@ -128,24 +157,19 @@ impl Metaspace {
}
Ok((
Metaspace {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new(
replacement,
add_prefix_space,
))),
},
PyMetaspaceDec {},
PyDecoder::new(Arc::new(Metaspace::new(replacement, add_prefix_space))),
))
}
}
#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
pub struct BPEDecoder {}
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=BPEDecoder)]
pub struct PyBPEDecoder {}
#[pymethods]
impl BPEDecoder {
impl PyBPEDecoder {
#[new]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyDecoder)> {
let mut suffix = String::from("</w>");
if let Some(kwargs) = kwargs {
@@ -159,27 +183,25 @@ impl BPEDecoder {
}
Ok((
BPEDecoder {},
Decoder {
decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))),
},
PyBPEDecoder {},
PyDecoder::new(Arc::new(BPEDecoder::new(suffix))),
))
}
}
struct PyDecoder {
struct CustomDecoder {
class: PyObject,
}
impl PyDecoder {
impl CustomDecoder {
pub fn new(class: PyObject) -> PyResult<Self> {
Ok(PyDecoder { class })
Ok(CustomDecoder { class })
}
}
#[typetag::serde]
impl tk::tokenizer::Decoder for PyDecoder {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
impl Decoder for CustomDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
let gil = Python::acquire_gil();
let py = gil.python();
@@ -199,7 +221,7 @@ impl tk::tokenizer::Decoder for PyDecoder {
}
}
impl Serialize for PyDecoder {
impl Serialize for CustomDecoder {
fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
@@ -210,11 +232,11 @@ impl Serialize for PyDecoder {
}
}
impl<'de> Deserialize<'de> for PyDecoder {
impl<'de> Deserialize<'de> for CustomDecoder {
fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
unimplemented!("PyDecoder cannot be deserialized")
Err(D::Error::custom("PyDecoder cannot be deserialized"))
}
}
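
A behavioral note on the hunk above: CustomDecoder::deserialize used to hit unimplemented!(), a panic that unwinds across the pyo3 boundary; it now returns a serde error, which __setstate__ converts into an ordinary Python exception. The same idea as a self-contained sketch (CustomDecoderSketch is illustrative, not the crate's type):

use serde::de::Error;
use serde::{Deserialize, Deserializer};

#[derive(Debug)]
struct CustomDecoderSketch;

impl<'de> Deserialize<'de> for CustomDecoderSketch {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Return an error the caller can surface as an exception,
        // instead of aborting the process with unimplemented!().
        Err(D::Error::custom("custom decoders cannot be deserialized"))
    }
}

fn main() {
    let err = serde_json::from_str::<CustomDecoderSketch>("null").unwrap_err();
    assert!(err.to_string().contains("cannot be deserialized"));
}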


@@ -10,7 +10,6 @@ mod processors;
mod token;
mod tokenizer;
mod trainers;
mod utils;
use pyo3::prelude::*;
use pyo3::wrap_pymodule;
@@ -71,11 +70,11 @@ fn pre_tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
/// Decoders Module
#[pymodule]
fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<decoders::Decoder>()?;
m.add_class::<decoders::ByteLevel>()?;
m.add_class::<decoders::WordPiece>()?;
m.add_class::<decoders::Metaspace>()?;
m.add_class::<decoders::BPEDecoder>()?;
m.add_class::<decoders::PyDecoder>()?;
m.add_class::<decoders::PyByteLevelDec>()?;
m.add_class::<decoders::PyWordPieceDec>()?;
m.add_class::<decoders::PyMetaspaceDec>()?;
m.add_class::<decoders::PyBPEDecoder>()?;
Ok(())
}


@@ -12,14 +12,13 @@ use tk::tokenizer::{
};
use tokenizers as tk;
use super::decoders::Decoder;
use super::decoders::PyDecoder;
use super::encoding::PyEncoding;
use super::error::{PyError, ToPyResult};
use super::models::PyModel;
use super::normalizers::PyNormalizer;
use super::pre_tokenizers::PyPreTokenizer;
use super::trainers::PyTrainer;
use super::utils::Container;
use crate::processors::PyPostProcessor;
#[pyclass(dict, module = "tokenizers", name=AddedToken)]
@@ -268,9 +267,10 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
}
}
type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor>;
type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;
#[pyclass(dict, module = "tokenizers", name=Tokenizer)]
#[derive(Clone)]
pub struct PyTokenizer {
tokenizer: TokenizerImpl,
}
@@ -666,15 +666,13 @@ impl PyTokenizer {
Ok(self.tokenizer.add_special_tokens(&tokens))
}
fn train(&mut self, _trainer: &PyTrainer, _files: Vec<String>) -> PyResult<()> {
// TODO enable training once Tokenizer derives Clone
// self.tokenizer = self.tokenizer.clone().train(trainer, files).map_err(|e|
// exceptions::Exception::py_err(format!("{}", e))
// )?;
// Ok(())
Err(exceptions::NotImplementedError::py_err(
"Training currently disabled",
))
fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
self.tokenizer = self
.tokenizer
.clone()
.train(trainer, files)
.map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
Ok(())
}
#[args(pair = "None", add_special_tokens = true)]
@@ -737,21 +735,12 @@ impl PyTokenizer {
}
#[getter]
fn get_decoder(&self) -> PyResult<Option<Decoder>> {
Ok(self.tokenizer.get_decoder().map(|decoder| Decoder {
decoder: Container::from_ref(decoder),
}))
fn get_decoder(&self) -> Option<PyDecoder> {
self.tokenizer.get_decoder().cloned()
}
#[setter]
fn set_decoder(&mut self, mut decoder: PyRefMut<Decoder>) -> PyResult<()> {
if let Some(decoder) = decoder.decoder.to_pointer() {
self.tokenizer.with_decoder(decoder);
Ok(())
} else {
Err(exceptions::Exception::py_err(
"The Decoder is already being used in another Tokenizer",
))
}
fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
self.tokenizer.with_decoder(decoder.clone());
}
}
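
The re-enabled train above relies on a clone-then-swap pattern: the underlying train consumes the tokenizer by value, so the binding clones the (now cheap, Arc-backed) tokenizer, trains the clone, and stores the result only on success, leaving the original intact on error. In miniature, with simplified signatures (Tok is a stand-in):

#[derive(Clone)]
struct Tok {
    vocab: Vec<String>,
}

impl Tok {
    // Consuming, builder-style train, mirroring the by-value call above.
    fn train(mut self, words: &[&str]) -> Result<Tok, String> {
        self.vocab.extend(words.iter().map(|w| w.to_string()));
        Ok(self)
    }
}

fn main() -> Result<(), String> {
    let mut tok = Tok { vocab: Vec::new() };
    // Clone first: a training failure leaves tok untouched; on
    // success the trained clone replaces the original.
    tok = tok.clone().train(&["hello", "world"])?;
    assert_eq!(tok.vocab.len(), 2);
    Ok(())
}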


@@ -1,64 +0,0 @@
/// A Container type
///
/// Provides an interface to allow transfer of ownership between Python and Rust.
/// It either contains a Box with full ownership of the content, or a pointer to the content.
///
/// The main goal here is to allow Python calling into Rust to initialize some objects. Later
/// these objects may need to be used by Rust who will expect to take ownership. Since Python
/// does not allow any sort of ownership transfer, it will keep a reference to this object
/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
/// and just keep a pointer in the Python object.
pub enum Container<T: ?Sized> {
Owned(Box<T>),
Pointer(*mut T),
}
impl<T> Container<T>
where
T: ?Sized,
{
pub fn from_ref(reference: &Box<T>) -> Self {
let content: *const T = &**reference;
Container::Pointer(content as *mut _)
}
/// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
pub fn take(self) -> Option<Box<T>> {
match self {
Container::Owned(obj) => Some(obj),
Container::Pointer(_) => None,
}
}
/// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
pub fn to_pointer(&mut self) -> Option<Box<T>> {
if let Container::Owned(_) = self {
unsafe {
let old_container = std::ptr::read(self);
let ptr = Box::into_raw(old_container.take().unwrap());
let new_container = Container::Pointer(ptr);
std::ptr::write(self, new_container);
Some(Box::from_raw(ptr))
}
} else {
None
}
}
pub fn execute<F, U>(&self, closure: F) -> U
where
F: FnOnce(&Box<T>) -> U,
{
match self {
Container::Owned(val) => closure(val),
Container::Pointer(ptr) => unsafe {
let val = Box::from_raw(*ptr);
let res = closure(&val);
// We call this to make sure we don't drop the Box
Box::into_raw(val);
res
},
}
}
}
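
For the record, the soundness hole this deletion closes: from_ref erases the lifetime of the borrowed Box, so a Pointer variant can outlive the owner it aliases; and execute briefly re-materializes a Box it does not own, so a panic inside the closure would skip Box::into_raw and double-free. A condensed illustration of the first hazard, stopping just short of the undefined behavior:

#[allow(dead_code)]
enum Container<T: ?Sized> {
    Owned(Box<T>),
    Pointer(*mut T),
}

fn main() {
    let owner: Box<i32> = Box::new(42);
    // What Container::from_ref did: stash a raw pointer with no
    // lifetime tying it back to the owning Box.
    let alias = Container::Pointer(&*owner as *const i32 as *mut i32);
    drop(owner);
    // alias now dangles: calling the old execute here would read freed
    // memory (undefined behavior). The Arc-based wrappers make this
    // state unrepresentable.
    if let Container::Pointer(p) = alias {
        let _dangling: *mut i32 = p; // deliberately never dereferenced
    }
}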