From 86138337fc8f072254cc787e8b4d3704aaa89ce4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 2 Aug 2024 18:41:54 +0200 Subject: [PATCH] Using serde (serde_pyo3) to get __str__ and __repr__ easily. --- bindings/python/Cargo.toml | 1 + bindings/python/src/decoders.rs | 2 +- bindings/python/src/models.rs | 2 +- bindings/python/src/normalizers.rs | 2 +- bindings/python/src/pre_tokenizers.rs | 2 +- bindings/python/src/processors.rs | 2 +- bindings/python/src/tokenizer.rs | 21 ++++++++++++++++++++- bindings/python/src/trainers.rs | 2 +- 8 files changed, 27 insertions(+), 7 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index b494e408..8a81ac3d 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -18,6 +18,7 @@ pyo3 = { version = "0.21" } numpy = "0.21" ndarray = "0.15" itertools = "0.12" +serde_pyo3 = { git = "https://github.com/Narsil/serde_pyo3" } [dependencies.tokenizers] path = "../../tokenizers" diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index ed21f346..4a4af94d 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -29,8 +29,8 @@ use super::error::ToPyResult; /// a Decoder will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.decoders", name = "Decoder", subclass)] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyDecoder { - #[serde(flatten)] pub(crate) decoder: PyDecoderWrapper, } diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index bffa1bc2..2bfaafd3 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -26,8 +26,8 @@ use super::error::{deprecation_warning, ToPyResult}; /// This class cannot be constructed directly. Please use one of the concrete models. 
#[pyclass(module = "tokenizers.models", name = "Model", subclass)] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyModel { - #[serde(flatten)] pub model: Arc<RwLock<ModelWrapper>>, } diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index 864947e3..724e79b8 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -44,8 +44,8 @@ impl PyNormalizedStringMut<'_> { /// Normalizer will return an instance of this class when instantiated. #[pyclass(dict, module = "tokenizers.normalizers", name = "Normalizer", subclass)] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyNormalizer { - #[serde(flatten)] pub(crate) normalizer: PyNormalizerTypeWrapper, } diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index a2bd9b39..a9060ec3 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -35,8 +35,8 @@ use super::utils::*; subclass )] #[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct PyPreTokenizer { - #[serde(flatten)] pub(crate) pretok: PyPreTokenizerTypeWrapper, } diff --git a/bindings/python/src/processors.rs b/bindings/python/src/processors.rs index c46d8ea4..aceb1d44 100644 --- a/bindings/python/src/processors.rs +++ b/bindings/python/src/processors.rs @@ -28,8 +28,8 @@ use tokenizers as tk; subclass )] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyPostProcessor { - #[serde(flatten)] pub processor: Arc<PostProcessorWrapper>, } diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs index 1c6bc9cc..5bc57f77 100644 --- a/bindings/python/src/tokenizer.rs +++ b/bindings/python/src/tokenizer.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use std::collections::{hash_map::DefaultHasher, HashMap}; use std::hash::{Hash, Hasher}; @@ -462,7 +463,8 @@ type Tokenizer = TokenizerImpl PyResult<String> { + serde_pyo3::to_string(self).map_err(|e| 
exceptions::PyException::new_err(e.to_string())) + } + /// Return the number of special tokens that would be added for single/pair sentences. /// :param is_pair: Boolean indicating if the input would be a single sentence or a pair /// :return: @@ -1434,4 +1441,16 @@ mod test { Tokenizer::from_file(&tmp).unwrap(); } + + #[test] + fn serde_pyo3() { + let mut tokenizer = Tokenizer::new(PyModel::from(BPE::default())); + tokenizer.with_normalizer(PyNormalizer::new(PyNormalizerTypeWrapper::Sequence(vec![ + Arc::new(RwLock::new(NFKC.into())), + Arc::new(RwLock::new(Lowercase.into())), + ]))); + + let output = serde_pyo3::to_string(&tokenizer).unwrap(); + assert_eq!(output, ""); + } } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 716e4cfe..cbce2aef 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -16,8 +16,8 @@ use tokenizers as tk; /// Trainer will return an instance of this class when instantiated. #[pyclass(module = "tokenizers.trainers", name = "Trainer", subclass)] #[derive(Clone, Deserialize, Serialize)] +#[serde(transparent)] pub struct PyTrainer { - #[serde(flatten)] pub trainer: Arc<RwLock<TrainerWrapper>>, }