Using serde (serde_pyo3) to get __str__ and __repr__ easily.

Nicolas Patry
2024-08-02 18:41:54 +02:00
parent 7415e28536
commit 86138337fc
8 changed files with 27 additions and 7 deletions
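The mechanism is simple: every wrapper type in the Python bindings already derives serde's `Serialize`, so a serializer whose output format is a Python-style repr string yields `__str__` and `__repr__` for free instead of hand-maintained formatting code. A minimal sketch of the idea, assuming only what the diff itself shows about serde_pyo3 (a `to_string` function that takes any `Serialize` value and returns a `Result<String, E>`); the struct and the exact output are illustrative, not taken from the crate:

```rust
use serde::Serialize;

// Hypothetical type, for illustration only.
#[derive(Serialize)]
struct Padding {
    length: usize,
    pad_token: String,
}

fn main() {
    let pad = Padding { length: 128, pad_token: "[PAD]".into() };
    // serde_pyo3 walks the Serialize impl and renders a Python-flavored
    // repr, presumably along the lines of:
    //   Padding(length=128, pad_token="[PAD]")
    let repr = serde_pyo3::to_string(&pad).unwrap();
    println!("{repr}");
}
```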

bindings/python/Cargo.toml

@@ -18,6 +18,7 @@ pyo3 = { version = "0.21" }
 numpy = "0.21"
 ndarray = "0.15"
 itertools = "0.12"
+serde_pyo3 = { git = "https://github.com/Narsil/serde_pyo3" }
 
 [dependencies.tokenizers]
 path = "../../tokenizers"

bindings/python/src/decoders.rs

@@ -29,8 +29,8 @@ use super::error::ToPyResult;
 /// a Decoder will return an instance of this class when instantiated.
 #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyDecoder {
-    #[serde(flatten)]
     pub(crate) decoder: PyDecoderWrapper,
 }
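The same one-line swap repeats for every wrapper struct below. Both attributes hide the wrapper layer, but differently: `#[serde(flatten)]` routes the field through serde's map machinery, which erases the struct and enum names a repr-style serializer wants to print, whereas `#[serde(transparent)]` makes a single-field struct delegate directly to its field's own `Serialize` impl, names intact. Since each wrapper has exactly one field, `transparent` is the simpler, name-preserving fit. A self-contained sketch of the wrapper-erasing effect, using serde_json (not part of this commit) only because its output is easy to compare:

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Inner {
    kind: &'static str,
}

// Serialized as-is, the wrapper adds a layer: {"inner":{"kind":"BPE"}}
#[derive(Serialize)]
struct Wrapped {
    inner: Inner,
}

// With #[serde(transparent)] the struct serializes exactly as its
// single field, so the wrapper layer disappears: {"kind":"BPE"}
#[derive(Serialize)]
#[serde(transparent)]
struct SeeThrough {
    inner: Inner,
}

fn main() {
    let w = Wrapped { inner: Inner { kind: "BPE" } };
    let t = SeeThrough { inner: Inner { kind: "BPE" } };
    assert_eq!(serde_json::to_string(&w).unwrap(), r#"{"inner":{"kind":"BPE"}}"#);
    assert_eq!(serde_json::to_string(&t).unwrap(), r#"{"kind":"BPE"}"#);
}
```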

bindings/python/src/models.rs

@@ -26,8 +26,8 @@ use super::error::{deprecation_warning, ToPyResult};
 /// This class cannot be constructed directly. Please use one of the concrete models.
 #[pyclass(module = "tokenizers.models", name = "Model", subclass)]
 #[derive(Clone, Serialize, Deserialize)]
+#[serde(transparent)]
 pub struct PyModel {
-    #[serde(flatten)]
     pub model: Arc<RwLock<ModelWrapper>>,
 }

bindings/python/src/normalizers.rs

@@ -44,8 +44,8 @@ impl PyNormalizedStringMut<'_> {
 /// Normalizer will return an instance of this class when instantiated.
 #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)]
 #[derive(Clone, Serialize, Deserialize)]
+#[serde(transparent)]
 pub struct PyNormalizer {
-    #[serde(flatten)]
     pub(crate) normalizer: PyNormalizerTypeWrapper,
 }

bindings/python/src/pre_tokenizers.rs

@@ -35,8 +35,8 @@ use super::utils::*;
     subclass
 )]
 #[derive(Clone, Serialize, Deserialize)]
+#[serde(transparent)]
 pub struct PyPreTokenizer {
-    #[serde(flatten)]
     pub(crate) pretok: PyPreTokenizerTypeWrapper,
 }

bindings/python/src/processors.rs

@@ -28,8 +28,8 @@ use tokenizers as tk;
     subclass
 )]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyPostProcessor {
-    #[serde(flatten)]
     pub processor: Arc<PostProcessorWrapper>,
 }

bindings/python/src/tokenizer.rs

@@ -1,3 +1,4 @@
+use serde::Serialize;
 use std::collections::{hash_map::DefaultHasher, HashMap};
 use std::hash::{Hash, Hasher};
@@ -462,7 +463,8 @@ type Tokenizer = TokenizerImpl<PyModel, PyNormalizer, PyPreTokenizer, PyPostProcessor, PyDecoder>;
 /// The core algorithm that this :obj:`Tokenizer` should be using.
 ///
 #[pyclass(dict, module = "tokenizers", name = "Tokenizer")]
-#[derive(Clone)]
+#[derive(Clone, Serialize)]
+#[serde(transparent)]
 pub struct PyTokenizer {
     tokenizer: Tokenizer,
 }
@@ -638,6 +640,11 @@ impl PyTokenizer {
         ToPyResult(self.tokenizer.save(path, pretty)).into()
     }
 
+    #[pyo3(signature = ())]
+    fn repr(&self) -> PyResult<String> {
+        serde_pyo3::to_string(self).map_err(|e| exceptions::PyException::new_err(e.to_string()))
+    }
+
     /// Return the number of special tokens that would be added for single/pair sentences.
     /// :param is_pair: Boolean indicating if the input would be a single sentence or a pair
     /// :return:
@@ -1434,4 +1441,16 @@ mod test {
         Tokenizer::from_file(&tmp).unwrap();
     }
+
+    #[test]
+    fn serde_pyo3() {
+        let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default()));
+        tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![
+            Arc::new(RwLock::new(NFKC.into())),
+            Arc::new(RwLock::new(Lowercase.into())),
+        ])));
+
+        let output = serde_pyo3::to_string(&tokenizer).unwrap();
+        assert_eq!(output, "");
+    }
 }
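Two things stand out in the tokenizer.rs changes. First, the new method lands as a plain `repr()` inside the `#[pymethods]` block; the natural end state is to wire the same call into Python's dunder protocol, roughly as below (the `__str__`/`__repr__` spelling is standard pyo3, but this exact wiring is a guess, not part of this commit). Second, `assert_eq!(output, "")` cannot pass once serialization produces anything, so the test reads as a placeholder to be filled with the expected repr string in a follow-up.

```rust
// Hypothetical follow-up inside the #[pymethods] impl PyTokenizer block;
// pyo3 maps __str__/__repr__ methods onto the Python dunders.
fn __str__(&self) -> PyResult<String> {
    serde_pyo3::to_string(self).map_err(|e| exceptions::PyException::new_err(e.to_string()))
}

fn __repr__(&self) -> PyResult<String> {
    serde_pyo3::to_string(self).map_err(|e| exceptions::PyException::new_err(e.to_string()))
}
```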

bindings/python/src/trainers.rs

@@ -16,8 +16,8 @@ use tokenizers as tk;
 /// Trainer will return an instance of this class when instantiated.
 #[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)]
 #[derive(Clone, Deserialize, Serialize)]
+#[serde(transparent)]
 pub struct PyTrainer {
-    #[serde(flatten)]
     pub trainer: Arc<RwLock<TrainerWrapper>>,
 }