Remove Container, changes to PyDecoder, cloneable Tokenizer.
* derive Clone on Tokenizer and AddedVocabulary.
* Replace Container with Arc wrapper for Decoders.
* Prefix Rust Decoder types with Py.
* Rename PyDecoder to CustomDecoder.
* Change panic in serializing custom decoder to exception.
* Re-enable training with cloneable Tokenizer.
* Remove unsound Container, use Arc wrappers instead.
Committed by Anthony MOI
parent ece6ad9149
commit d62adf7195
@@ -1,50 +1,85 @@
-extern crate tokenizers as tk;
+use std::sync::Arc;

-use super::error::{PyError, ToPyResult};
-use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-use tk::tokenizer::Result;
+use tk::decoders::bpe::BPEDecoder;
+use tk::decoders::byte_level::ByteLevel;
+use tk::decoders::metaspace::Metaspace;
+use tk::decoders::wordpiece::WordPiece;
+use tk::Decoder;
+use tokenizers as tk;
+
+use super::error::{PyError, ToPyResult};

-#[pyclass(dict, module = "tokenizers.decoders")]
-pub struct Decoder {
-    pub decoder: Container<dyn tk::tokenizer::Decoder>,
+#[pyclass(dict, module = "tokenizers.decoders", name=Decoder)]
+#[derive(Clone)]
+pub struct PyDecoder {
+    pub decoder: Arc<dyn Decoder>,
 }
+
+impl PyDecoder {
+    pub fn new(decoder: Arc<dyn Decoder>) -> Self {
+        PyDecoder { decoder }
+    }
+}
+
+#[typetag::serde]
+impl Decoder for PyDecoder {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
+        self.decoder.decode(tokens)
+    }
+}
+
+impl Serialize for PyDecoder {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        self.decoder.serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for PyDecoder {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        Ok(PyDecoder {
+            decoder: Arc::deserialize(deserializer)?,
+        })
+    }
+}

 #[pymethods]
-impl Decoder {
+impl PyDecoder {
     #[staticmethod]
     fn custom(decoder: PyObject) -> PyResult<Self> {
-        let decoder = PyDecoder::new(decoder)?;
-        Ok(Decoder {
-            decoder: Container::Owned(Box::new(decoder)),
-        })
+        let decoder = CustomDecoder::new(decoder).map(Arc::new)?;
+        Ok(PyDecoder::new(decoder))
     }

     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
-        let data = self
-            .decoder
-            .execute(|decoder| serde_json::to_string(&decoder))
-            .map_err(|e| {
-                exceptions::Exception::py_err(format!(
-                    "Error while attempting to pickle Decoder: {}",
-                    e.to_string()
-                ))
-            })?;
+        let data = serde_json::to_string(&self.decoder).map_err(|e| {
+            exceptions::Exception::py_err(format!(
+                "Error while attempting to pickle Decoder: {}",
+                e
+            ))
+        })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
     }

     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
-                self.decoder =
-                    Container::Owned(serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                        exceptions::Exception::py_err(format!(
-                            "Error while attempting to unpickle Decoder: {}",
-                            e.to_string()
-                        ))
-                    })?);
+                self.decoder = serde_json::from_slice(s.as_bytes()).map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to unpickle Decoder: {}",
+                        e
+                    ))
+                })?;
                 Ok(())
             }
             Err(e) => Err(e),
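Note: the new PyDecoder serialization above delegates to typetag, which (de)serializes `dyn Decoder` trait objects behind Box/Arc by tagging each concrete type. A minimal self-contained sketch of the mechanism, with illustrative names (`Greet`, `Hello`) and assumed crates serde, serde_json, and typetag:

    use serde::{Deserialize, Serialize};

    #[typetag::serde]
    trait Greet {
        fn greet(&self) -> String;
    }

    #[derive(Serialize, Deserialize)]
    struct Hello {
        name: String,
    }

    #[typetag::serde]
    impl Greet for Hello {
        fn greet(&self) -> String {
            format!("hello {}", self.name)
        }
    }

    fn main() -> serde_json::Result<()> {
        let g: Box<dyn Greet> = Box::new(Hello { name: "world".into() });
        // Serialized with an internal "type" tag, e.g. {"type":"Hello","name":"world"},
        // which is what lets the pickled bytes round-trip back into a trait object.
        let json = serde_json::to_string(&g)?;
        let back: Box<dyn Greet> = serde_json::from_str(&json)?;
        assert_eq!(back.greet(), "hello world");
        Ok(())
    }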
@@ -52,32 +87,30 @@ impl Decoder {
     }

     fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
-        ToPyResult(self.decoder.execute(|decoder| decoder.decode(tokens))).into()
+        ToPyResult(self.decoder.decode(tokens)).into()
     }
 }

-#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
-pub struct ByteLevel {}
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=ByteLevel)]
+pub struct PyByteLevelDec {}
 #[pymethods]
-impl ByteLevel {
+impl PyByteLevelDec {
     #[new]
-    fn new() -> PyResult<(Self, Decoder)> {
+    fn new() -> PyResult<(Self, PyDecoder)> {
         Ok((
-            ByteLevel {},
-            Decoder {
-                decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::default())),
-            },
+            PyByteLevelDec {},
+            PyDecoder::new(Arc::new(ByteLevel::default())),
         ))
     }
 }

-#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
-pub struct WordPiece {}
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=WordPiece)]
+pub struct PyWordPieceDec {}
 #[pymethods]
-impl WordPiece {
+impl PyWordPieceDec {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyDecoder)> {
         let mut prefix = String::from("##");
         let mut cleanup = true;

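The `extends=PyDecoder` classes above use pyo3's native subclassing: a `#[new]` that returns a `(Subclass, BaseClass)` tuple initializes both layers of the Python object at once. A rough sketch of the pattern with illustrative names (recent pyo3 additionally requires the base to be marked `subclass`):

    use pyo3::prelude::*;

    #[pyclass(subclass)]
    #[derive(Clone)]
    struct Base {
        label: String,
    }

    #[pyclass(extends=Base)]
    struct Child {}

    #[pymethods]
    impl Child {
        #[new]
        fn new() -> (Self, Base) {
            // The tuple fills in the subclass and its base in one shot;
            // Python sees a single Child instance whose base slot holds Base.
            (Child {}, Base { label: "child".into() })
        }
    }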
@@ -91,23 +124,19 @@ impl WordPiece {
         }

         Ok((
-            WordPiece {},
-            Decoder {
-                decoder: Container::Owned(Box::new(tk::decoders::wordpiece::WordPiece::new(
-                    prefix, cleanup,
-                ))),
-            },
+            PyWordPieceDec {},
+            PyDecoder::new(Arc::new(WordPiece::new(prefix, cleanup))),
         ))
     }
 }

-#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
-pub struct Metaspace {}
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=Metaspace)]
+pub struct PyMetaspaceDec {}
 #[pymethods]
-impl Metaspace {
+impl PyMetaspaceDec {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyDecoder)> {
         let mut replacement = '▁';
         let mut add_prefix_space = true;

@@ -128,24 +157,19 @@ impl Metaspace {
         }

         Ok((
-            Metaspace {},
-            Decoder {
-                decoder: Container::Owned(Box::new(tk::decoders::metaspace::Metaspace::new(
-                    replacement,
-                    add_prefix_space,
-                ))),
-            },
+            PyMetaspaceDec {},
+            PyDecoder::new(Arc::new(Metaspace::new(replacement, add_prefix_space))),
         ))
     }
 }

-#[pyclass(extends=Decoder, module = "tokenizers.decoders")]
-pub struct BPEDecoder {}
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=BPEDecoder)]
+pub struct PyBPEDecoder {}
 #[pymethods]
-impl BPEDecoder {
+impl PyBPEDecoder {
     #[new]
     #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, Decoder)> {
+    fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyDecoder)> {
         let mut suffix = String::from("</w>");

         if let Some(kwargs) = kwargs {
@@ -159,27 +183,25 @@ impl BPEDecoder {
         }

         Ok((
-            BPEDecoder {},
-            Decoder {
-                decoder: Container::Owned(Box::new(tk::decoders::bpe::BPEDecoder::new(suffix))),
-            },
+            PyBPEDecoder {},
+            PyDecoder::new(Arc::new(BPEDecoder::new(suffix))),
         ))
     }
 }

-struct PyDecoder {
+struct CustomDecoder {
     class: PyObject,
 }

-impl PyDecoder {
+impl CustomDecoder {
     pub fn new(class: PyObject) -> PyResult<Self> {
-        Ok(PyDecoder { class })
+        Ok(CustomDecoder { class })
     }
 }

 #[typetag::serde]
-impl tk::tokenizer::Decoder for PyDecoder {
-    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+impl Decoder for CustomDecoder {
+    fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
         let gil = Python::acquire_gil();
         let py = gil.python();

@@ -199,7 +221,7 @@ impl tk::tokenizer::Decoder for PyDecoder {
     }
 }

-impl Serialize for PyDecoder {
+impl Serialize for CustomDecoder {
     fn serialize<S>(&self, _serializer: S) -> std::result::Result<S::Ok, S::Error>
     where
         S: Serializer,
@@ -210,11 +232,11 @@ impl Serialize for PyDecoder {
     }
 }

-impl<'de> Deserialize<'de> for PyDecoder {
+impl<'de> Deserialize<'de> for CustomDecoder {
     fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
     where
         D: Deserializer<'de>,
     {
-        unimplemented!("PyDecoder cannot be deserialized")
+        Err(D::Error::custom("PyDecoder cannot be deserialized"))
     }
 }
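The last hunk above is the "change panic to exception" bullet from the commit message: `unimplemented!` would abort straight through the FFI boundary, while `serde::de::Error::custom` produces an ordinary error that the unpickling code can surface as a Python exception. The same pattern in isolation (type name is illustrative):

    use serde::de::Error as _;
    use serde::{Deserialize, Deserializer};

    struct Opaque;

    impl<'de> Deserialize<'de> for Opaque {
        fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            // Recoverable failure instead of a process-killing panic.
            Err(D::Error::custom("Opaque cannot be deserialized"))
        }
    }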
@@ -10,7 +10,6 @@ mod processors;
 mod token;
 mod tokenizer;
 mod trainers;
-mod utils;

 use pyo3::prelude::*;
 use pyo3::wrap_pymodule;
@@ -71,11 +70,11 @@ fn pre_tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
 /// Decoders Module
 #[pymodule]
 fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
-    m.add_class::<decoders::Decoder>()?;
-    m.add_class::<decoders::ByteLevel>()?;
-    m.add_class::<decoders::WordPiece>()?;
-    m.add_class::<decoders::Metaspace>()?;
-    m.add_class::<decoders::BPEDecoder>()?;
+    m.add_class::<decoders::PyDecoder>()?;
+    m.add_class::<decoders::PyByteLevelDec>()?;
+    m.add_class::<decoders::PyWordPieceDec>()?;
+    m.add_class::<decoders::PyMetaspaceDec>()?;
+    m.add_class::<decoders::PyBPEDecoder>()?;
     Ok(())
 }
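Although the Rust types gain a `Py` prefix, the `name=...` argument on each `#[pyclass]` keeps the Python-visible class names stable, so `tokenizers.decoders.ByteLevel` and friends keep working. A condensed sketch using the same pyo3 vintage as this diff (newer pyo3 expects a string literal for `name`):

    use pyo3::prelude::*;

    // Rust identifier is PyByteLevelDec, but Python sees `ByteLevel`.
    #[pyclass(name=ByteLevel)]
    struct PyByteLevelDec {}

    #[pymodule]
    fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
        // add_class registers the class under the `name` from the attribute,
        // not under the Rust identifier.
        m.add_class::<PyByteLevelDec>()?;
        Ok(())
    }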
@@ -12,14 +12,13 @@ use tk::tokenizer::{
 };
 use tokenizers as tk;

-use super::decoders::Decoder;
+use super::decoders::PyDecoder;
 use super::encoding::PyEncoding;
 use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
 use super::normalizers::PyNormalizer;
 use super::pre_tokenizers::PyPreTokenizer;
 use super::trainers::PyTrainer;
-use super::utils::Container;
 use crate::processors::PyPostProcessor;

 #[pyclass(dict, module = "tokenizers", name=AddedToken)]
@@ -268,9 +267,10 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }

-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;

 #[pyclass(dict, module = "tokenizers", name=Tokenizer)]
+#[derive(Clone)]
 pub struct PyTokenizer {
     tokenizer: TokenizerImpl,
 }
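Making `PyDecoder` a fifth type parameter of `TokenizerImpl` is what lets `PyTokenizer` simply `#[derive(Clone)]`: a generic struct is cloneable as soon as every component type is. Illustrative shape only, not the real `tokenizers` definition:

    #[derive(Clone)]
    struct Tokenizer<M, N, PT, PP, D>
    where
        M: Clone,
        N: Clone,
        PT: Clone,
        PP: Clone,
        D: Clone,
    {
        model: M,
        normalizer: Option<N>,
        pre_tokenizer: Option<PT>,
        post_processor: Option<PP>,
        decoder: Option<D>,
    }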
@@ -666,15 +666,13 @@ impl PyTokenizer {
         Ok(self.tokenizer.add_special_tokens(&tokens))
     }

-    fn train(&mut self, _trainer: &PyTrainer, _files: Vec<String>) -> PyResult<()> {
-        // TODO enable training once Tokenizer derives Clone
-        // self.tokenizer = self.tokenizer.clone().train(trainer, files).map_err(|e|
-        //     exceptions::Exception::py_err(format!("{}", e))
-        // )?;
-        // Ok(())
-        Err(exceptions::NotImplementedError::py_err(
-            "Training currently disabled",
-        ))
+    fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
+        self.tokenizer = self
+            .tokenizer
+            .clone()
+            .train(trainer, files)
+            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
+        Ok(())
     }

     #[args(pair = "None", add_special_tokens = true)]
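The re-enabled `train` relies on a clone-then-swap idiom: training consumes the tokenizer by value, so the code trains a clone and only assigns it back on success, leaving `self` intact if training fails. In miniature, with stand-in types:

    #[derive(Clone)]
    struct Trainable {
        steps: u32,
    }

    impl Trainable {
        // Consumes self, returns the trained value or an error.
        fn train(mut self) -> Result<Self, String> {
            self.steps += 1;
            Ok(self)
        }
    }

    fn update(state: &mut Trainable) -> Result<(), String> {
        *state = state.clone().train()?; // old state survives if train() errs
        Ok(())
    }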
@@ -737,21 +735,12 @@ impl PyTokenizer {
     }

     #[getter]
-    fn get_decoder(&self) -> PyResult<Option<Decoder>> {
-        Ok(self.tokenizer.get_decoder().map(|decoder| Decoder {
-            decoder: Container::from_ref(decoder),
-        }))
+    fn get_decoder(&self) -> Option<PyDecoder> {
+        self.tokenizer.get_decoder().cloned()
     }

     #[setter]
-    fn set_decoder(&mut self, mut decoder: PyRefMut<Decoder>) -> PyResult<()> {
-        if let Some(decoder) = decoder.decoder.to_pointer() {
-            self.tokenizer.with_decoder(decoder);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Decoder is already being used in another Tokenizer",
-            ))
-        }
+    fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
+        self.tokenizer.with_decoder(decoder.clone());
     }
 }
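The setter also loses its "already being used in another Tokenizer" error path: cloning an `Arc`-backed wrapper only bumps a reference count, so one decoder can safely back any number of tokenizers. A minimal demonstration with illustrative names:

    use std::sync::Arc;

    trait Decode {
        fn id(&self) -> usize;
    }

    #[derive(Clone)]
    struct Wrapper {
        inner: Arc<dyn Decode>,
    }

    struct Dummy;
    impl Decode for Dummy {
        fn id(&self) -> usize {
            42
        }
    }

    fn main() {
        let w1 = Wrapper { inner: Arc::new(Dummy) };
        let w2 = w1.clone(); // shared ownership, no transfer needed
        assert_eq!(Arc::strong_count(&w1.inner), 2);
        assert_eq!(w2.inner.id(), 42);
    }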
@@ -1,64 +0,0 @@
-/// A Container type
-///
-/// Provides an interface to allow transfer of ownership between Python and Rust.
-/// It either contains a Box with full ownership of the content, or a pointer to the content.
-///
-/// The main goal here is to allow Python calling into Rust to initialize some objects. Later
-/// these objects may need to be used by Rust who will expect to take ownership. Since Python
-/// does not allow any sort of ownership transfer, it will keep a reference to this object
-/// until it gets cleaned up by the GC. In this case, we actually give the ownership to Rust,
-/// and just keep a pointer in the Python object.
-pub enum Container<T: ?Sized> {
-    Owned(Box<T>),
-    Pointer(*mut T),
-}
-
-impl<T> Container<T>
-where
-    T: ?Sized,
-{
-    pub fn from_ref(reference: &Box<T>) -> Self {
-        let content: *const T = &**reference;
-        Container::Pointer(content as *mut _)
-    }
-
-    /// Consumes ourself and return the Boxed element if we have the ownership, None otherwise.
-    pub fn take(self) -> Option<Box<T>> {
-        match self {
-            Container::Owned(obj) => Some(obj),
-            Container::Pointer(_) => None,
-        }
-    }
-
-    /// Return the owned T, keeping a Pointer to it if we currently own it. None otherwise
-    pub fn to_pointer(&mut self) -> Option<Box<T>> {
-        if let Container::Owned(_) = self {
-            unsafe {
-                let old_container = std::ptr::read(self);
-                let ptr = Box::into_raw(old_container.take().unwrap());
-                let new_container = Container::Pointer(ptr);
-                std::ptr::write(self, new_container);
-
-                Some(Box::from_raw(ptr))
-            }
-        } else {
-            None
-        }
-    }
-
-    pub fn execute<F, U>(&self, closure: F) -> U
-    where
-        F: FnOnce(&Box<T>) -> U,
-    {
-        match self {
-            Container::Owned(val) => closure(val),
-            Container::Pointer(ptr) => unsafe {
-                let val = Box::from_raw(*ptr);
-                let res = closure(&val);
-                // We call this to make sure we don't drop the Box
-                Box::into_raw(val);
-                res
-            },
-        }
-    }
-}
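For reference, the unsoundness in the deleted Container: `from_ref` casts a shared borrow to `*mut T` without owning it, and `execute` briefly rebuilds a `Box` from that raw pointer, so a dangling pointer or double free is one misuse away. A safe-code illustration of the hazard and of the Arc replacement (the dangling pointer is never dereferenced):

    use std::sync::Arc;

    fn main() {
        let owner: Box<String> = Box::new("payload".into());
        let alias: *mut String = &*owner as *const String as *mut String;
        drop(owner); // the only real owner is gone
        let _dangling = alias; // any use of `alias` now would be undefined behavior

        // The Arc version makes sharing explicit and safe:
        let a = Arc::new(String::from("payload"));
        let b = Arc::clone(&a); // both handles co-own; freed exactly once
        assert_eq!(*b, "payload");
    }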