diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 416ae3c2..c51774d8 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -301,7 +301,7 @@ mod test {
         match py_dec.decoder {
             PyDecoderWrapper::Wrapped(msp) => match msp.as_ref() {
                 DecoderWrapper::Metaspace(_) => {}
-                _ => panic!("Expected Whitespace"),
+                _ => panic!("Expected Metaspace"),
             },
             _ => panic!("Expected wrapped, not custom."),
         }
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 7e7c5051..a320d6d6 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -1,4 +1,3 @@
-use std::collections::hash_map::RandomState;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
@@ -10,11 +9,11 @@ use serde::{Deserialize, Serialize};
 use tk::models::bpe::BPE;
 use tk::models::wordlevel::WordLevel;
 use tk::models::wordpiece::WordPiece;
+use tk::models::ModelWrapper;
 use tk::{Model, Token};
 use tokenizers as tk;
 
 use super::error::ToPyResult;
-use tk::models::ModelWrapper;
 
 /// A Model represents some tokenization algorithm like BPE or Word
 /// This class cannot be constructed directly. Please use one of the concrete models.
@@ -55,7 +54,7 @@ impl Model for PyModel {
         self.model.id_to_token(id)
     }
 
-    fn get_vocab(&self) -> &HashMap<String, u32, RandomState> {
+    fn get_vocab(&self) -> &HashMap<String, u32> {
        self.model.get_vocab()
    }
diff --git a/tokenizers/src/decoders/bpe.rs b/tokenizers/src/decoders/bpe.rs
index 4f9de55f..80078fab 100644
--- a/tokenizers/src/decoders/bpe.rs
+++ b/tokenizers/src/decoders/bpe.rs
@@ -1,10 +1,11 @@
 use crate::tokenizer::{Decoder, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-#[derive(Deserialize, Clone, Debug)]
+use serde::{Deserialize, Serialize};
+
+#[derive(Deserialize, Clone, Debug, Serialize)]
 /// Allows decoding Original BPE by joining all the tokens and then replacing
 /// the suffix used to identify end-of-words by whitespaces
+#[serde(tag = "type")]
 pub struct BPEDecoder {
     suffix: String,
 }
@@ -26,15 +27,3 @@ impl Decoder for BPEDecoder {
         Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
     }
 }
-
-impl Serialize for BPEDecoder {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("BPEDecoder", 2)?;
-        m.serialize_field("type", "BPEDecoder")?;
-        m.serialize_field("suffix", &self.suffix)?;
-        m.end()
-    }
-}
diff --git a/tokenizers/src/decoders/wordpiece.rs b/tokenizers/src/decoders/wordpiece.rs
index 49ae1f64..3a4aeb23 100644
--- a/tokenizers/src/decoders/wordpiece.rs
+++ b/tokenizers/src/decoders/wordpiece.rs
@@ -1,10 +1,11 @@
 use crate::tokenizer::{Decoder, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-#[derive(Deserialize, Clone, Debug)]
+use serde::{Deserialize, Serialize};
+
+#[derive(Deserialize, Clone, Debug, Serialize)]
 /// The WordPiece decoder takes care of decoding a list of wordpiece tokens
 /// back into a readable string.
+#[serde(tag = "type")]
 pub struct WordPiece {
     /// The prefix to be used for continuing subwords
     prefix: String,
@@ -48,16 +49,3 @@ impl Decoder for WordPiece {
         Ok(output)
     }
 }
-
-impl Serialize for WordPiece {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("BPEDecoder", 3)?;
-        m.serialize_field("type", "BPEDecoder")?;
-        m.serialize_field("prefix", &self.prefix)?;
-        m.serialize_field("cleanup", &self.cleanup)?;
-        m.end()
-    }
-}
diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs
index 8613cca1..1cf4e0b9 100644
--- a/tokenizers/src/lib.rs
+++ b/tokenizers/src/lib.rs
@@ -29,7 +29,7 @@
 //! use tokenizers::models::bpe::BPE;
 //!
 //! fn main() -> Result<()> {
-//!     let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
+//!     let bpe_builder = BPE::from_files("./path/to/vocab.json", "./path/to/merges.txt");
 //!     let bpe = bpe_builder
 //!         .dropout(0.1)
 //!         .unk_token("[UNK]".into())
diff --git a/tokenizers/src/normalizers/bert.rs b/tokenizers/src/normalizers/bert.rs
index 5abea865..b5fc416d 100644
--- a/tokenizers/src/normalizers/bert.rs
+++ b/tokenizers/src/normalizers/bert.rs
@@ -1,6 +1,6 @@
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+
+use serde::{Deserialize, Serialize};
 use unicode_categories::UnicodeCategories;
 
 /// Checks whether a character is whitespace
@@ -49,7 +49,8 @@ fn is_chinese_char(c: char) -> bool {
     }
 }
 
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type")]
 pub struct BertNormalizer {
     /// Whether to do the bert basic cleaning:
     ///   1. Remove any control characters
     ///   2. Replace all sorts of whitespace by the classic one
     clean_text: bool,
     handle_chinese_chars: bool,
     strip_accents: bool,
     lowercase: bool,
 }
 
-impl Serialize for BertNormalizer {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("BertNormalizer", 5)?;
-        m.serialize_field("type", "BertNormalizer")?;
-        m.serialize_field("clean_text", &self.clean_text)?;
-        m.serialize_field("handle_chinese_chars", &self.handle_chinese_chars)?;
-        m.serialize_field("strip_accents", &self.strip_accents)?;
-        m.serialize_field("lowercase", &self.lowercase)?;
-        m.end()
-    }
-}
-
 impl Default for BertNormalizer {
     fn default() -> Self {
         Self {
diff --git a/tokenizers/src/normalizers/strip.rs b/tokenizers/src/normalizers/strip.rs
index 4c65196b..b7d0bfc7 100644
--- a/tokenizers/src/normalizers/strip.rs
+++ b/tokenizers/src/normalizers/strip.rs
@@ -1,26 +1,13 @@
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
 
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type")]
 pub struct Strip {
     strip_left: bool,
     strip_right: bool,
 }
 
-impl Serialize for Strip {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("Strip", 5)?;
-        m.serialize_field("type", "BertNormalizer")?;
-        m.serialize_field("strip_left", &self.strip_left)?;
-        m.serialize_field("strip_right", &self.strip_right)?;
-        m.end()
-    }
-}
-
 impl Strip {
     pub fn new(strip_left: bool, strip_right: bool) -> Self {
         Strip {
diff --git a/tokenizers/src/normalizers/utils.rs b/tokenizers/src/normalizers/utils.rs
index 56136409..f4d4136b 100644
--- a/tokenizers/src/normalizers/utils.rs
+++ b/tokenizers/src/normalizers/utils.rs
@@ -1,27 +1,16 @@
+use serde::{Deserialize, Serialize};
+
 use crate::normalizers::NormalizerWrapper;
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-#[derive(Clone, Deserialize, Debug)]
+#[derive(Clone, Deserialize, Debug, Serialize)]
+#[serde(tag = "type")]
 /// Allows concatenating multiple other Normalizer as a Sequence.
 /// All the normalizers run in sequence in the given order against the same NormalizedString.
 pub struct Sequence {
     normalizers: Vec<NormalizerWrapper>,
 }
 
-impl Serialize for Sequence {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("Sequence", 2)?;
-        m.serialize_field("type", "Sequence")?;
-        m.serialize_field("normalizers", &self.normalizers)?;
-        m.end()
-    }
-}
-
 impl Sequence {
     pub fn new(normalizers: Vec<NormalizerWrapper>) -> Self {
         Self { normalizers }
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 928eb20b..fdaa2805 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -1,10 +1,11 @@
+use std::collections::{HashMap, HashSet};
+
+use onig::Regex;
+use serde::{Deserialize, Serialize};
+
 use crate::tokenizer::{
     normalizer::Range, Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
 };
-use onig::Regex;
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-use std::collections::{HashMap, HashSet};
 
 fn bytes_char() -> HashMap<u8, char> {
     let mut bs: Vec<u8> = vec![];
@@ -38,10 +39,11 @@ lazy_static! {
         bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
 }
 
-#[derive(Deserialize, Copy, Clone, Debug)]
+#[derive(Deserialize, Serialize, Copy, Clone, Debug)]
 /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
 /// of all the required processing steps to transform a UTF-8 string as needed before and after the
 /// BPE model does its job.
+#[serde(tag = "type")]
 pub struct ByteLevel {
     /// Whether to add a leading space to the first word. This allows to treat the leading word
     /// just as any other word.
@@ -212,19 +214,6 @@ pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
     });
 }
 
-impl Serialize for ByteLevel {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("ByteLevel", 3)?;
-        m.serialize_field("type", "ByteLevel")?;
-        m.serialize_field("add_prefix_space", &self.add_prefix_space)?;
-        m.serialize_field("trim_offsets", &self.trim_offsets)?;
-        m.end()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::ByteLevel;
diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs
index 54c82cde..ce486cf8 100644
--- a/tokenizers/src/pre_tokenizers/delimiter.rs
+++ b/tokenizers/src/pre_tokenizers/delimiter.rs
@@ -1,9 +1,9 @@
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type")]
 pub struct CharDelimiterSplit {
     delimiter: char,
 }
@@ -22,15 +22,3 @@ impl PreTokenizer for CharDelimiterSplit {
         })
     }
 }
-
-impl Serialize for CharDelimiterSplit {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("CharDelimiterSplit", 2)?;
-        m.serialize_field("type", "CharDelimiterSplit")?;
-        m.serialize_field("delimiter", &self.delimiter)?;
-        m.end()
-    }
-}
diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs
index 1f722d63..da16a433 100644
--- a/tokenizers/src/pre_tokenizers/metaspace.rs
+++ b/tokenizers/src/pre_tokenizers/metaspace.rs
@@ -1,11 +1,11 @@
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
 
 use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Deserialize, Clone, Debug)]
+#[derive(Serialize, Deserialize, Clone, Debug)]
 /// Replaces all the whitespaces by the provided meta character and then
 /// splits on this character
+#[serde(tag = "type")]
 pub struct Metaspace {
     replacement: char,
     str_rep: String,
@@ -68,20 +68,6 @@ impl Decoder for Metaspace {
     }
 }
 
-impl Serialize for Metaspace {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("Metaspace", 3)?;
-        m.serialize_field("type", "Metaspace")?;
-        m.serialize_field("replacement", &self.replacement)?;
-        m.serialize_field("str_rep", &self.str_rep)?;
-        m.serialize_field("add_prefix_space", &self.add_prefix_space)?;
-        m.end()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs
index a8ebc567..f7d631f2 100644
--- a/tokenizers/src/pre_tokenizers/whitespace.rs
+++ b/tokenizers/src/pre_tokenizers/whitespace.rs
@@ -1,16 +1,17 @@
 use std::fmt;
 
 use regex::Regex;
-use serde::de::{Error, Visitor};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
 };
+use serde::de::{Error, Visitor};
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Serialize)]
+#[serde(tag = "type")]
 pub struct Whitespace {
+    #[serde(default = "default_regex", skip)]
     re: Regex,
 }
@@ -34,19 +35,8 @@ impl PreTokenizer for Whitespace {
     }
 }
 
-// manually implement serialize / deserialize because Whitespace is not a unit-struct but is
+// manually implement deserialize because Whitespace is not a unit-struct but is
 // serialized like one.
-impl Serialize for Whitespace {
-    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut m = serializer.serialize_struct("Whitespace", 1)?;
-        m.serialize_field("type", "Whitespace")?;
-        m.end()
-    }
-}
-
 impl<'de> Deserialize<'de> for Whitespace {
     fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
     where
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index d2fe83c0..03dfcf49 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -657,10 +657,6 @@ where
     /// ```
     /// # use tokenizers::Tokenizer;
     /// # use tokenizers::models::bpe::BPE;
-    /// # use tokenizers::normalizers::NormalizerWrapper;
-    /// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
-    /// # use tokenizers::processors::PostProcessorWrapper;
-    /// # use tokenizers::decoders::DecoderWrapper;
     /// # let mut tokenizer = Tokenizer::new(BPE::default());
     /// #
     /// // Sequences:
diff --git a/tokenizers/tests/serialization.rs b/tokenizers/tests/serialization.rs
index 87203cc7..1c4f18d3 100644
--- a/tokenizers/tests/serialization.rs
+++ b/tokenizers/tests/serialization.rs
@@ -125,6 +125,15 @@ fn pretoks() {
     // wrapped serializes same way as inner
     let ser_wrapped = serde_json::to_string(&ch_wrapped).unwrap();
     assert_eq!(ser_wrapped, ch_ser);
+
+    let wsp = Whitespace::default();
+    let wsp_ser = serde_json::to_string(&wsp).unwrap();
+    serde_json::from_str::<Whitespace>(&wsp_ser).unwrap();
+    let err: Result<BertPreTokenizer, _> = serde_json::from_str(&wsp_ser);
+    assert!(
+        err.is_err(),
+        "BertPreTokenizer shouldn't be deserializable from Whitespace"
+    );
 }
 
 #[test]
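
Note on the pattern this diff converges on: serde's `#[serde(tag = "type")]` attribute makes the derived `Serialize` impl emit the same `"type"` discriminant that the removed hand-written `SerializeStruct` impls used to write by hand, and the derived `Deserialize` rejects a payload whose tag names a different type (which is what the new `Whitespace`/`BertPreTokenizer` check in serialization.rs exercises). Dropping the manual impls also removes the copy-paste bugs visible above, where the WordPiece decoder and the Strip normalizer wrote the wrong `type` string. A minimal standalone sketch of the mechanism, outside the crate (`MyDecoder` is an illustrative stand-in, not a type from tokenizers):

```rust
// Standalone sketch: `MyDecoder` stands in for structs like BPEDecoder to
// show what deriving with `#[serde(tag = "type")]` produces.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(tag = "type")]
struct MyDecoder {
    suffix: String,
}

fn main() -> serde_json::Result<()> {
    let dec = MyDecoder { suffix: "</w>".into() };

    // The derive injects the discriminant automatically, e.g.
    // {"type":"MyDecoder","suffix":"</w>"} -- no manual SerializeStruct needed.
    let json = serde_json::to_string(&dec)?;
    assert!(json.contains(r#""type":"MyDecoder""#));

    // Deserialization goes back through the same tagged representation; a
    // payload carrying a different `type` value would be rejected.
    let back: MyDecoder = serde_json::from_str(&json)?;
    assert_eq!(back, dec);
    Ok(())
}
```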