Remove Container from PostProcessors, replace with Arc.

* Prefix the Python types in Rust with Py.
* Remove unsound Container wrappers, replace with Arc (a sketch of the new pattern follows the commit metadata).
Author: Sebastian Pütz, 2020-07-25 19:34:38 +02:00
Committed by: Anthony MOI
Parent: 0f81997351
Commit: 11e86a16c5
3 changed files with 100 additions and 83 deletions
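For context on the pattern this commit introduces: the old Container wrapper handed out exclusive ownership through a raw pointer, so attaching the same processor to a second tokenizer failed at runtime (see the removed error branch in the setter further down). The Arc-based wrapper makes the handle cheaply cloneable instead. Below is a minimal sketch of that pattern, using a simplified stand-in for the real PostProcessor trait; apart from the PyPostProcessor-around-Arc shape shown in the diff, the trait body is illustrative, not the library's actual definition.

    use std::sync::Arc;

    // Simplified stand-in for the tk::tokenizer::PostProcessor trait
    // (illustrative; the real trait has more methods).
    pub trait PostProcessor: Send + Sync {
        fn added_tokens(&self, is_pair: bool) -> usize;
    }

    // The new wrapper: a shared handle instead of owned/raw storage.
    // Cloning only bumps the reference count, so the Python-side object
    // and the Tokenizer can both hold the same underlying processor.
    #[derive(Clone)]
    pub struct PyPostProcessor {
        pub processor: Arc<dyn PostProcessor>,
    }

    // Forwarding impl so the wrapper itself satisfies the trait bound
    // the Tokenizer places on its post-processor type parameter.
    impl PostProcessor for PyPostProcessor {
        fn added_tokens(&self, is_pair: bool) -> usize {
            self.processor.added_tokens(is_pair)
        }
    }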


@@ -7,7 +7,7 @@ use pyo3::types::*;
 use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
-    PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationParams,
+    PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, Tokenizer, TruncationParams,
     TruncationStrategy,
 };
 use tokenizers as tk;
@@ -18,9 +18,9 @@ use super::error::{PyError, ToPyResult};
 use super::models::PyModel;
 use super::normalizers::PyNormalizer;
 use super::pre_tokenizers::PyPreTokenizer;
-use super::processors::PostProcessor;
 use super::trainers::PyTrainer;
-use super::utils::Container;
+use crate::processors::PyPostProcessor;

 #[pyclass(dict, module = "tokenizers", name=AddedToken)]
 pub struct PyAddedToken {
@@ -268,9 +268,9 @@ impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
     }
 }

-type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer>;
+type TokenizerImpl = Tokenizer<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor>;

-#[pyclass(dict, module = "tokenizers")]
+#[pyclass(dict, module = "tokenizers", name=Tokenizer)]
 pub struct PyTokenizer {
     tokenizer: TokenizerImpl,
 }
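The fourth type parameter added above is what the rest of the diff leans on: the tokenizer now stores the concrete PyPostProcessor rather than a Container-wrapped trait object. A hypothetical, simplified sketch of the generic side follows; the field layout and method bodies are assumptions for illustration, since only with_post_processor and get_post_processor appear in this diff.

    // Hypothetical, simplified shape of the generic Tokenizer: the
    // post-processor is a concrete type parameter rather than a trait
    // object hidden behind a Container.
    pub struct Tokenizer<M, N, PT, PP> {
        model: M,
        normalizer: Option<N>,
        pre_tokenizer: Option<PT>,
        post_processor: Option<PP>,
    }

    impl<M, N, PT, PP> Tokenizer<M, N, PT, PP> {
        // Storing the processor takes it by value; with an Arc-backed PP
        // the caller keeps its own handle by cloning first.
        pub fn with_post_processor(&mut self, post_processor: PP) -> &mut Self {
            self.post_processor = Some(post_processor);
            self
        }

        // Returning Option<&PP> is what lets the new Python getter call
        // .cloned(): it clones the Arc wrapper, not the processor itself.
        pub fn get_post_processor(&self) -> Option<&PP> {
            self.post_processor.as_ref()
        }
    }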
@@ -362,7 +362,7 @@ impl PyTokenizer {
         Ok(self
             .tokenizer
             .get_post_processor()
-            .map_or(0, |p| p.as_ref().added_tokens(is_pair)))
+            .map_or(0, |p| p.added_tokens(is_pair)))
     }

     #[args(with_added_tokens = true)]
@@ -727,25 +727,13 @@ impl PyTokenizer {
     }

     #[getter]
-    fn get_post_processor(&self) -> PyResult<Option<PostProcessor>> {
-        Ok(self
-            .tokenizer
-            .get_post_processor()
-            .map(|processor| PostProcessor {
-                processor: Container::from_ref(processor),
-            }))
+    fn get_post_processor(&self) -> Option<PyPostProcessor> {
+        self.tokenizer.get_post_processor().cloned()
     }

     #[setter]
-    fn set_post_processor(&mut self, mut processor: PyRefMut<PostProcessor>) -> PyResult<()> {
-        if let Some(processor) = processor.processor.to_pointer() {
-            self.tokenizer.with_post_processor(processor);
-            Ok(())
-        } else {
-            Err(exceptions::Exception::py_err(
-                "The Processor is already being used in another Tokenizer",
-            ))
-        }
+    fn set_post_processor(&mut self, processor: PyRef<PyPostProcessor>) {
+        self.tokenizer.with_post_processor(processor.clone());
     }

     #[getter]
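With the Arc-backed wrapper, the fallible to_pointer transfer above disappears: setting a processor clones a cheap handle rather than moving unique ownership out of the Python object. A usage sketch building on the simplified types from the first example; Bert is a made-up processor whose counts mimic a BERT-style [CLS]/[SEP] scheme.

    // A made-up processor for the sketch.
    struct Bert;

    impl PostProcessor for Bert {
        fn added_tokens(&self, is_pair: bool) -> usize {
            // [CLS] A [SEP] adds 2 tokens; a pair adds one more [SEP].
            if is_pair { 3 } else { 2 }
        }
    }

    fn main() {
        let py_proc = PyPostProcessor { processor: Arc::new(Bert) };

        // The old Container::to_pointer moved the processor out, so a
        // second attach errored; cloning here just bumps the refcount.
        let for_tokenizer = py_proc.clone();

        assert_eq!(for_tokenizer.added_tokens(true), 3);
        assert_eq!(py_proc.added_tokens(false), 2); // original handle still valid
    }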