mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Using serde (serde_pyo3) to get __str__ and __repr__ easily.
This commit is contained in:
@ -18,6 +18,7 @@ pyo3 = { version = "0.21" }
|
||||
numpy = "0.21"
|
||||
ndarray = "0.15"
|
||||
itertools = "0.12"
|
||||
serde_pyo3 = { git = "https://github.com/Narsil/serde_pyo3" }
|
||||
|
||||
[dependencies.tokenizers]
|
||||
path = "../../tokenizers"
|
||||
|
@ -29,8 +29,8 @@ use super::error::ToPyResult;
|
||||
/// a Decoder will return an instance of this class when instantiated.
|
||||
#[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyDecoder {
|
||||
#[serde(flatten)]
|
||||
pub(crate) decoder: PyDecoderWrapper,
|
||||
}
|
||||
|
||||
|
@ -26,8 +26,8 @@ use super::error::{deprecation_warning, ToPyResult};
|
||||
/// This class cannot be constructed directly. Please use one of the concrete models.
|
||||
#[pyclass(module = "tokenizers.models", name = "Model", subclass)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyModel {
|
||||
#[serde(flatten)]
|
||||
pub model: Arc<RwLock<ModelWrapper>>,
|
||||
}
|
||||
|
||||
|
@ -44,8 +44,8 @@ impl PyNormalizedStringMut<'_> {
|
||||
/// Normalizer will return an instance of this class when instantiated.
|
||||
#[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyNormalizer {
|
||||
#[serde(flatten)]
|
||||
pub(crate) normalizer: PyNormalizerTypeWrapper,
|
||||
}
|
||||
|
||||
|
@ -35,8 +35,8 @@ use super::utils::*;
|
||||
subclass
|
||||
)]
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyPreTokenizer {
|
||||
#[serde(flatten)]
|
||||
pub(crate) pretok: PyPreTokenizerTypeWrapper,
|
||||
}
|
||||
|
||||
|
@ -28,8 +28,8 @@ use tokenizers as tk;
|
||||
subclass
|
||||
)]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyPostProcessor {
|
||||
#[serde(flatten)]
|
||||
pub processor: Arc<PostProcessorWrapper>,
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
use serde::Serialize;
|
||||
use std::collections::{hash_map::DefaultHasher, HashMap};
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
@ -462,7 +463,8 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProc
|
||||
/// The core algorithm that this :obj:`Tokenizer` should be using.
|
||||
///
|
||||
#[pyclass(dict, module = "tokenizers", name = "Tokenizer")]
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyTokenizer {
|
||||
tokenizer: Tokenizer,
|
||||
}
|
||||
@ -638,6 +640,11 @@ impl PyTokenizer {
|
||||
ToPyResult(self.tokenizer.save(path, pretty)).into()
|
||||
}
|
||||
|
||||
#[pyo3(signature = ())]
|
||||
fn repr(&self) -> PyResult<String> {
|
||||
serde_pyo3::to_string(self).map_err(|e| exceptions::PyException::new_err(e.to_string()))
|
||||
}
|
||||
|
||||
/// Return the number of special tokens that would be added for single/pair sentences.
|
||||
/// :param is_pair: Boolean indicating if the input would be a single sentence or a pair
|
||||
/// :return:
|
||||
@ -1434,4 +1441,16 @@ mod test {
|
||||
|
||||
Tokenizer::from_file(&tmp).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_pyo3() {
|
||||
let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
|
||||
tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
|
||||
Arc::new(RwLock::new(NFKC.into())),
|
||||
Arc::new(RwLock::new(Lowercase.into())),
|
||||
])));
|
||||
|
||||
let output = serde_pyo3::to_string(&tokenizer).unwrap();
|
||||
assert_eq!(output, "");
|
||||
}
|
||||
}
|
||||
|
@ -16,8 +16,8 @@ use tokenizers as tk;
|
||||
/// Trainer will return an instance of this class when instantiated.
|
||||
#[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)]
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct PyTrainer {
|
||||
#[serde(flatten)]
|
||||
pub trainer: Arc<RwLock<TrainerWrapper>>,
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user