Remove Container, changes to PyDecoder, cloneable Tokenizer.

* Derive Clone on Tokenizer and AddedVocabulary.
* Replace Container with an Arc wrapper for Decoders (see the sketch below).
* Prefix Rust Decoder types with Py.
* Rename PyDecoder to CustomDecoder.
* Raise an exception instead of panicking when serializing a custom decoder.
* Re-enable training with cloneable Tokenizer.
* Remove unsound Container, use Arc wrappers instead.
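
The gist of the Arc-wrapper approach can be pictured with a small, self-contained sketch. The names here (the Decoder trait stand-in, PyDecoderSketch, TokenizerSketch, WordPieceLike) are illustrative only, not the crate's actual types: the Python-facing wrapper shares ownership of its decoder through an Arc, so cloning the wrapper, and therefore the whole tokenizer, only bumps a reference count instead of moving ownership around through raw pointers.

    use std::sync::Arc;

    // Stand-in for the library's Decoder trait.
    trait Decoder: Send + Sync {
        fn decode(&self, tokens: Vec<String>) -> String;
    }

    struct WordPieceLike;

    impl Decoder for WordPieceLike {
        fn decode(&self, tokens: Vec<String>) -> String {
            tokens.join(" ").replace(" ##", "")
        }
    }

    // Instead of an unsound Container shuffling raw pointers, the wrapper
    // holds the decoder behind an Arc; cloning it is a reference-count bump.
    #[derive(Clone)]
    struct PyDecoderSketch {
        decoder: Arc<dyn Decoder>,
    }

    #[derive(Clone)]
    struct TokenizerSketch {
        decoder: Option<PyDecoderSketch>,
    }

    fn main() {
        let decoder = PyDecoderSketch {
            decoder: Arc::new(WordPieceLike),
        };

        let mut tokenizer = TokenizerSketch { decoder: None };
        // The setter clones the handle; the caller keeps its own copy too.
        tokenizer.decoder = Some(decoder.clone());

        // The whole tokenizer is now cheap to clone as well.
        let copy = tokenizer.clone();
        let text = copy
            .decoder
            .unwrap()
            .decoder
            .decode(vec!["to".into(), "##ken".into(), "##izer".into()]);
        assert_eq!(text, "tokenizer");
    }

Because every handle is just shared ownership, the old "Decoder is already being used in another Tokenizer" failure mode goes away: the same decoder can back several tokenizers at once.
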
Sebastian Pütz
2020-07-25 20:07:24 +02:00
committed by Anthony MOI
parent ece6ad9149
commit d62adf7195
6 changed files with 117 additions and 169 deletions


@@ -12,14 +12,13 @@ use tk::tokenizer::{
};
use tokenizers as tk;
-use super::decoders::Decoder;
+use super::decoders::PyDecoder;
use super::encoding::PyEncoding;
use super::error::{PyError, ToPyResult};
use super::models::PyModel;
use super::normalizers::PyNormalizer;
use super::pre_tokenizers::PyPreTokenizer;
use super::trainers::PyTrainer;
-use super::utils::Container;
use crate::processors::PyPostProcessor;
#[pyclass(dict, module = "tokenizers", name=AddedToken)]
@@ -268,9 +267,10 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
}
}
-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;
#[pyclass(dict, module = "tokenizers", name=Tokenizer)]
+#[derive(Clone)]
pub struct PyTokenizer {
tokenizer: TokenizerImpl,
}
@@ -666,15 +666,13 @@ impl PyTokenizer {
Ok(self.tokenizer.add_special_tokens(&tokens))
}
-    fn train(&mut self, _trainer: &PyTrainer, _files: Vec<String>) -> PyResult<()> {
-        // TODO enable training once Tokenizer derives Clone
-        // self.tokenizer = self.tokenizer.clone().train(trainer, files).map_err(|e|
-        // exceptions::Exception::py_err(format!("{}", e))
-        // )?;
-        // Ok(())
-        Err(exceptions::NotImplementedError::py_err(
-            "Training currently disabled",
-        ))
+    fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
+        self.tokenizer = self
+            .tokenizer
+            .clone()
+            .train(trainer, files)
+            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
+        Ok(())
}
#[args(pair = "None", add_special_tokens = true)]
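
For context on why the Clone derive is what unblocks this: train consumes the tokenizer by value and returns the trained instance (hence the .clone().train(...) chain above), while the Python-facing method only has &mut self and cannot move out of its field. A minimal sketch of that pattern, using assumed illustrative types (Tok, Trainer, Binding) rather than the real crate signatures:

    #[derive(Clone)]
    struct Tok {
        vocab_size: usize,
    }

    struct Trainer {
        target_vocab: usize,
    }

    impl Tok {
        // Training consumes the tokenizer and hands back the trained one.
        fn train(mut self, trainer: &Trainer, _files: Vec<String>) -> Result<Self, String> {
            self.vocab_size = trainer.target_vocab;
            Ok(self)
        }
    }

    struct Binding {
        tokenizer: Tok,
    }

    impl Binding {
        // A &mut self method cannot move self.tokenizer out, so it clones,
        // trains the clone, and writes the result back, the same shape as
        // the new PyTokenizer::train above.
        fn train(&mut self, trainer: &Trainer, files: Vec<String>) -> Result<(), String> {
            self.tokenizer = self.tokenizer.clone().train(trainer, files)?;
            Ok(())
        }
    }

    fn main() {
        let mut binding = Binding {
            tokenizer: Tok { vocab_size: 0 },
        };
        binding
            .train(&Trainer { target_vocab: 30_000 }, vec!["data.txt".into()])
            .unwrap();
        assert_eq!(binding.tokenizer.vocab_size, 30_000);
    }
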
@@ -737,21 +735,12 @@ impl PyTokenizer {
}
#[getter]
-    fn get_decoder(&self) -> PyResult<Option<Decoder>> {
-        Ok(self.tokenizer.get_decoder().map(|decoder| Decoder {
-            decoder: Container::from_ref(decoder),
-        }))
+    fn get_decoder(&self) -> Option<PyDecoder> {
+        self.tokenizer.get_decoder().cloned()
}
#[setter]
-    fn set_decoder(&mut self, mut decoder: PyRefMut<Decoder>) -> PyResult<()> {
-        if let Some(decoder) = decoder.decoder.to_pointer() {
-            self.tokenizer.with_decoder(decoder);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Decoder is already being used in another Tokenizer",
-            ))
-        }
+    fn set_decoder(&mut self, decoder: PyRef<PyDecoder>) {
+        self.tokenizer.with_decoder(decoder.clone());
}
}