From 9a9c70563a8eb8fc89d48a0c2d84fac3bed29726 Mon Sep 17 00:00:00 2001
From: Mishig Davaadorj
Date: Sat, 22 Jan 2022 22:53:02 +0100
Subject: [PATCH] Implement `impl_serde_type` macro

---
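Notes:

`#[macro_rules_attribute(impl_serde_type!)]` is the glue between the new
`proc_macros` crate and the declarative macro added in
`tokenizers/src/utils/mod.rs`: the attribute consumes the annotated item and
re-emits it as `impl_serde_type! { <item> }`, which lets a plain
`macro_rules!` macro act as an attribute. A hand-expanded sketch of the
`Digits` case from this patch (for illustration only, output abbreviated;
helper-type names follow the macro's `paste` concatenation scheme):

```rust
// Step 1: the attribute rewrites the annotated struct into a macro call;
// the remaining attributes (derives included) travel along as tokens.
impl_serde_type! {
    #[derive(Clone, Debug, PartialEq)]
    #[non_exhaustive]
    pub struct Digits {
        pub individual_digits: bool,
    }
}

// Step 2: `impl_serde_type!` re-emits the struct with a tagged serde setup,
// roughly:
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", from = "DigitsDeserializer")]
#[non_exhaustive]
pub struct Digits {
    pub individual_digits: bool,
}
// ...plus hidden `DigitsDef`, `DigitsType`, and `DigitsDeserializer` helpers
// that require and check `"type": "Digits"` during deserialization.
```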
 tokenizers/Cargo.toml                        |   2 +
 tokenizers/src/pre_tokenizers/byte_level.rs  |  30 +--
 tokenizers/src/pre_tokenizers/delimiter.rs   |  29 +--
 tokenizers/src/pre_tokenizers/digits.rs      |  29 +--
 tokenizers/src/pre_tokenizers/metaspace.rs   |  10 +-
 tokenizers/src/pre_tokenizers/punctuation.rs |  31 +---
 tokenizers/src/pre_tokenizers/sequence.rs    |  29 +--
 tokenizers/src/pre_tokenizers/split.rs       |   6 +-
 tokenizers/src/utils/mod.rs                  | 107 +++++++++++
 tokenizers/src/utils/proc_macros/Cargo.toml  |  11 ++
 tokenizers/src/utils/proc_macros/README.md   |   1 +
 tokenizers/src/utils/proc_macros/src/lib.rs  | 181 +++++++++++++++++++
 12 files changed, 330 insertions(+), 136 deletions(-)
 create mode 100644 tokenizers/src/utils/proc_macros/Cargo.toml
 create mode 100644 tokenizers/src/utils/proc_macros/README.md
 create mode 100644 tokenizers/src/utils/proc_macros/src/lib.rs

diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 691f6cbf..1ca9ba0b 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -57,6 +57,8 @@ dirs = "3.0"
 reqwest = { version = "0.11", optional = true }
 cached-path = { version = "0.5", optional = true }
 aho-corasick = "0.7"
+paste = "1.0.6"
+proc_macros = { path = "./src/utils/proc_macros" }
 
 [features]
 default = ["progressbar", "http"]
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 92b8c13a..2a8a811e 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -1,12 +1,13 @@
 use std::collections::{HashMap, HashSet};
 
 use onig::Regex;
-use serde::{Deserialize, Deserializer, Serialize};
+use serde::{Deserialize, Serialize};
 
 use crate::tokenizer::{
     Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
     SplitDelimiterBehavior,
 };
+use crate::utils::macro_rules_attribute;
 
 fn bytes_char() -> HashMap<u8, char> {
     let mut bs: Vec<u8> = vec![];
@@ -40,11 +41,11 @@ lazy_static! {
         bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
 }
 
-#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
 /// of all the required processing steps to transform a UTF-8 string as needed before and after the
 /// BPE model does its job.
-#[serde(tag = "type")]
+#[macro_rules_attribute(impl_serde_type!)]
 #[non_exhaustive]
 pub struct ByteLevel {
     /// Whether to add a leading space to the first word. This allows to treat the leading word
@@ -54,29 +55,6 @@ pub struct ByteLevel {
     pub trim_offsets: bool,
 }
 
-impl<'de> Deserialize<'de> for ByteLevel {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        enum Type {
-            ByteLevel,
-        }
-
-        #[derive(Deserialize)]
-        pub struct ByteLevelHelper {
-            #[serde(rename = "type")]
-            _type: Type,
-            add_prefix_space: bool,
-            trim_offsets: bool,
-        }
-
-        let helper = ByteLevelHelper::deserialize(deserializer)?;
-        Ok(ByteLevel::new(helper.add_prefix_space, helper.trim_offsets))
-    }
-}
-
 impl Default for ByteLevel {
     fn default() -> Self {
         Self {
diff --git a/tokenizers/src/pre_tokenizers/delimiter.rs b/tokenizers/src/pre_tokenizers/delimiter.rs
index 075a48e8..45935b52 100644
--- a/tokenizers/src/pre_tokenizers/delimiter.rs
+++ b/tokenizers/src/pre_tokenizers/delimiter.rs
@@ -1,36 +1,15 @@
-use serde::{Deserialize, Deserializer, Serialize};
+use serde::{Deserialize, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
+use crate::utils::macro_rules_attribute;
 
-#[derive(Copy, Clone, Debug, Serialize, PartialEq)]
-#[serde(tag = "type")]
+#[derive(Copy, Clone, Debug, PartialEq)]
 #[non_exhaustive]
+#[macro_rules_attribute(impl_serde_type!)]
 pub struct CharDelimiterSplit {
     pub delimiter: char,
 }
 
-impl<'de> Deserialize<'de> for CharDelimiterSplit {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        enum Type {
-            CharDelimiterSplit,
-        }
-
-        #[derive(Deserialize)]
-        pub struct CharDelimiterSplitHelper {
-            #[serde(rename = "type")]
-            _type: Type,
-            delimiter: char,
-        }
-
-        let helper = CharDelimiterSplitHelper::deserialize(deserializer)?;
-        Ok(CharDelimiterSplit::new(helper.delimiter))
-    }
-}
-
 impl CharDelimiterSplit {
     pub fn new(delimiter: char) -> Self {
         CharDelimiterSplit { delimiter }
diff --git a/tokenizers/src/pre_tokenizers/digits.rs b/tokenizers/src/pre_tokenizers/digits.rs
index 02f977c2..a64bab7c 100644
--- a/tokenizers/src/pre_tokenizers/digits.rs
+++ b/tokenizers/src/pre_tokenizers/digits.rs
@@ -1,38 +1,17 @@
-use serde::{Deserialize, Deserializer, Serialize};
+use serde::{Deserialize, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
+use crate::utils::macro_rules_attribute;
 
-#[derive(Serialize, Clone, Debug, PartialEq)]
+#[derive(Clone, Debug, PartialEq)]
 /// Pre tokenizes the numbers into single tokens. If individual_digits is set
 /// to true, then all digits are splitted into individual tokens.
-#[serde(tag = "type")]
 #[non_exhaustive]
+#[macro_rules_attribute(impl_serde_type!)]
 pub struct Digits {
     pub individual_digits: bool,
 }
 
-impl<'de> Deserialize<'de> for Digits {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        enum Type {
-            Digits,
-        }
-
-        #[derive(Deserialize)]
-        pub struct DigitsHelper {
-            #[serde(rename = "type")]
-            _type: Type,
-            individual_digits: bool,
-        }
-
-        let helper = DigitsHelper::deserialize(deserializer)?;
-        Ok(Digits::new(helper.individual_digits))
-    }
-}
-
 impl Digits {
     pub fn new(individual_digits: bool) -> Self {
         Self { individual_digits }
diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs
index 22f619cc..78fe4171 100644
--- a/tokenizers/src/pre_tokenizers/metaspace.rs
+++ b/tokenizers/src/pre_tokenizers/metaspace.rs
@@ -24,13 +24,11 @@ impl<'de> Deserialize<'de> for Metaspace {
         }
 
         #[derive(Deserialize)]
-        pub struct MetaspaceHelper {
-            #[serde(rename = "type")]
-            _type: Type,
+        struct MetaspaceHelper {
+            #[allow(dead_code)]
+            r#type: Type,
             replacement: char,
-            pub add_prefix_space: bool,
-            #[serde(skip, rename = "str_rep")]
-            _str_rep: String,
+            add_prefix_space: bool,
         }
 
         let helper = MetaspaceHelper::deserialize(deserializer)?;
diff --git a/tokenizers/src/pre_tokenizers/punctuation.rs b/tokenizers/src/pre_tokenizers/punctuation.rs
index 46341ac9..43e9bc28 100644
--- a/tokenizers/src/pre_tokenizers/punctuation.rs
+++ b/tokenizers/src/pre_tokenizers/punctuation.rs
@@ -1,41 +1,20 @@
-use serde::{Deserialize, Deserializer, Serialize};
+use serde::{Deserialize, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
+use crate::utils::macro_rules_attribute;
 use unicode_categories::UnicodeCategories;
 
 fn is_punc(x: char) -> bool {
     char::is_ascii_punctuation(&x) || x.is_punctuation()
 }
 
-#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
-#[serde(tag = "type")]
+#[derive(Copy, Clone, Debug, PartialEq)]
+#[macro_rules_attribute(impl_serde_type!)]
 pub struct Punctuation {
+    #[serde(default = "default_split")]
     behavior: SplitDelimiterBehavior,
 }
 
-impl<'de> Deserialize<'de> for Punctuation {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        enum Type {
-            Punctuation,
-        }
-
-        #[derive(Deserialize)]
-        pub struct PunctuationHelper {
-            #[serde(rename = "type")]
-            _type: Type,
-            #[serde(default = "default_split")]
-            behavior: SplitDelimiterBehavior,
-        }
-
-        let helper = PunctuationHelper::deserialize(deserializer)?;
-        Ok(Punctuation::new(helper.behavior))
-    }
-}
-
 fn default_split() -> SplitDelimiterBehavior {
     SplitDelimiterBehavior::Isolated
 }
diff --git a/tokenizers/src/pre_tokenizers/sequence.rs b/tokenizers/src/pre_tokenizers/sequence.rs
index 5d722471..6175cc2e 100644
--- a/tokenizers/src/pre_tokenizers/sequence.rs
+++ b/tokenizers/src/pre_tokenizers/sequence.rs
@@ -1,35 +1,14 @@
 use crate::pre_tokenizers::PreTokenizerWrapper;
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
-use serde::{Deserialize, Deserializer, Serialize};
+use crate::utils::macro_rules_attribute;
+use serde::{Deserialize, Serialize};
 
-#[derive(Clone, Debug, Serialize, PartialEq)]
-#[serde(tag = "type")]
+#[derive(Clone, Debug, PartialEq)]
+#[macro_rules_attribute(impl_serde_type!)]
 pub struct Sequence {
     pretokenizers: Vec<PreTokenizerWrapper>,
 }
 
-impl<'de> Deserialize<'de> for Sequence {
-    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize)]
-        enum Type {
-            Sequence,
-        }
-
-        #[derive(Deserialize)]
-        pub struct SequenceHelper {
-            #[serde(rename = "type")]
-            _type: Type,
-            pretokenizers: Vec<PreTokenizerWrapper>,
-        }
-
-        let helper = SequenceHelper::deserialize(deserializer)?;
-        Ok(Sequence::new(helper.pretokenizers))
-    }
-}
-
 impl Sequence {
     pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
         Self { pretokenizers }
diff --git a/tokenizers/src/pre_tokenizers/split.rs b/tokenizers/src/pre_tokenizers/split.rs
index 32ea25a8..2cff82bc 100644
--- a/tokenizers/src/pre_tokenizers/split.rs
+++ b/tokenizers/src/pre_tokenizers/split.rs
@@ -45,9 +45,9 @@ impl<'de> Deserialize<'de> for Split {
         }
 
         #[derive(Deserialize)]
-        pub struct SplitHelper {
-            #[serde(rename = "type")]
-            _type: Type,
+        struct SplitHelper {
+            #[allow(dead_code)]
+            r#type: Type,
             pattern: SplitPattern,
             behavior: SplitDelimiterBehavior,
             invert: bool,
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
index c49f10be..f6830e1a 100644
--- a/tokenizers/src/utils/mod.rs
+++ b/tokenizers/src/utils/mod.rs
@@ -74,3 +74,110 @@ macro_rules! impl_serde_unit_struct (
         }
     }
 );
+
+/// Implement `serde::{Serialize, Deserialize}` with the `#[serde(tag = "type")]` attribute for a given struct.
+/// Deserialization fails when the JSON string being deserialized is missing the `type` field.
+///
+/// # Examples
+///
+/// ```
+/// # #[macro_use] extern crate tokenizers;
+/// use serde::{Serialize, Deserialize};
+///
+/// fn main() {
+///     impl_serde_type!{
+///         #[derive(Debug)]
+///         struct Point {
+///             x: i32,
+///             #[serde(default = "default_y")]
+///             y: i32,
+///         }
+///     }
+///     fn default_y() -> i32 {
+///         5
+///     }
+///
+///     let point = Point { x: 1, y: 2 };
+///     let serialized_s = r#"{"type":"Point","x":1,"y":2}"#;
+///     assert_eq!(serde_json::to_string(&point).unwrap(), serialized_s);
+/// }
+/// ```
+///
+/// ```should_panic
+/// # #[macro_use] extern crate tokenizers;
+/// use serde::{Serialize, Deserialize};
+///
+/// fn main() {
+///     impl_serde_type!{
+///         #[derive(Debug)]
+///         struct Point1D {
+///             x: i32,
+///         }
+///     }
+///
+///     let serialized_s = r#"{"x":1}"#;
+///     let deserialized: Point1D = serde_json::from_str(serialized_s).unwrap();
+/// }
+/// ```
+#[macro_export]
+macro_rules! impl_serde_type{
+    (
+        $(#[$meta:meta])*
+        $vis:vis struct $struct_name:ident {
+            $(
+                $(#[$field_meta:meta])*
+                $field_vis:vis $field_name:ident : $field_type:ty
+            ),*$(,)+
+        }
+    ) => {
+        use paste::paste;
+
+        paste!{
+            $(#[$meta])*
+            #[derive(Serialize, Deserialize)]
+            #[serde(tag = "type", from = $struct_name "Deserializer")]
+            $vis struct $struct_name{
+                $(
+                    $(#[$field_meta])*
+                    $field_vis $field_name : $field_type,
+                )*
+            }
+
+            #[doc(hidden)]
+            $(#[$meta])*
+            #[derive(Deserialize)]
+            #[serde(tag = "type", remote = $struct_name "")]
+            struct [<$struct_name Def>]{
+                $(
+                    $(#[$field_meta])*
+                    $field_vis $field_name : $field_type,
+                )*
+            }
+
+            #[doc(hidden)]
+            #[derive(Deserialize)]
+            enum [<$struct_name Type>] {
+                $struct_name,
+            }
+
+            #[doc(hidden)]
+            #[derive(Deserialize)]
+            struct [<$struct_name Deserializer>] {
+                #[allow(dead_code)]
+                r#type: [<$struct_name Type>],
+                #[serde(flatten, with = $struct_name "Def")]
+                r#struct: $struct_name,
+            }
+
+            #[doc(hidden)]
+            impl std::convert::From<[<$struct_name Deserializer>]> for $struct_name {
+                fn from(v: [<$struct_name Deserializer>]) -> Self {
+                    v.r#struct
+                }
+            }
+        }
+    }
+}
+
+// Re-export macro_rules_attribute
+pub use proc_macros::macro_rules_attribute;
diff --git a/tokenizers/src/utils/proc_macros/Cargo.toml b/tokenizers/src/utils/proc_macros/Cargo.toml
new file mode 100644
index 00000000..9e6fe06a
--- /dev/null
+++ b/tokenizers/src/utils/proc_macros/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "proc_macros"
+version = "0.1.0"
+edition = "2018"
+
+[lib]
+proc-macro = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
diff --git a/tokenizers/src/utils/proc_macros/README.md b/tokenizers/src/utils/proc_macros/README.md
new file mode 100644
index 00000000..f50523f1
--- /dev/null
+++ b/tokenizers/src/utils/proc_macros/README.md
@@ -0,0 +1 @@
+Todo here
\ No newline at end of file
diff --git a/tokenizers/src/utils/proc_macros/src/lib.rs b/tokenizers/src/utils/proc_macros/src/lib.rs
new file mode 100644
index 00000000..6c4b189e
--- /dev/null
+++ b/tokenizers/src/utils/proc_macros/src/lib.rs
@@ -0,0 +1,181 @@
+//! Do not use this crate directly. Instead, use [`::macro_rules_attribute`](
+//! https://docs.rs/macro_rules_attribute)
+
+extern crate proc_macro;
+use ::proc_macro::*;
+
+/// Applies the given `macro_rules!` macro to the decorated item.
+///
+/// This, as with any `proc_macro_attribute`, **consumes** the item it
+/// decorates: it is the `macro_rules!` macro's job to generate it (_it is thus
+/// able to modify it_!).
+///
+/// For a version with "read-only" access to the item it decorates, see
+/// [`macro_rules_derive`][`macro@macro_rules_derive`].
+///
+/// # Example
+///
+/// Deriving getters for a (non-generic) `struct`:
+///
+/// ```rust
+/// # macro_rules! ignore {($($tt:tt)*) => () }
+/// # ignore! {
+/// #[macro_use]
+/// extern crate macro_rules_attribute;
+/// # }
+///
+/// macro_rules! make_getters {(
+///     $(#[$struct_meta:meta])*
+///     $struct_vis:vis
+///     struct $StructName:ident {
+///         $(
+///             $(#[$field_meta:meta])*
+///             $field_vis:vis // this visibility will be applied to the getters instead
+///             $field_name:ident : $field_ty:ty
+///         ),* $(,)?
+///     }
+/// ) => (
+///     // First, generate the struct definition we have been given, but with
+///     // private fields instead.
+///     $(#[$struct_meta])*
+///     $struct_vis
+///     struct $StructName {
+///         $(
+///             $(#[$field_meta])*
+///             // notice the lack of visibility => private fields
+///             $field_name: $field_ty,
+///         )*
+///     }
+///
+///     // Then, implement the getters:
+///     impl $StructName {
+///         $(
+///             #[inline]
+///             $field_vis
+///             fn $field_name (self: &'_ Self)
+///               -> &'_ $field_ty
+///             {
+///                 &self.$field_name
+///             }
+///         )*
+///     }
+/// )}
+///
+/// mod example {
+///     # use ::macro_rules_attribute_proc_macro::macro_rules_attribute;
+///     #[macro_rules_attribute(make_getters!)]
+///     /// The macro handles meta attributes such as docstrings
+///     pub
+///     struct Person {
+///         pub
+///         name: String,
+///
+///         pub
+///         age: u8,
+///     }
+/// }
+/// use example::Person;
+///
+/// fn is_new_born (person: &'_ Person)
+///   -> bool
+/// {
+///     // person.age == 0
+///     // ^ error[E0616]: field `age` of struct `example::Person` is private
+///     *person.age() == 0
+/// }
+/// ```
+#[proc_macro_attribute] pub
+fn macro_rules_attribute (
+    attrs: TokenStream,
+    input: TokenStream,
+) -> TokenStream
+{
+    // check that `attrs` is indeed of the form `$macro_name:path !`
+    {
+        // FIXME: do this properly
+        match attrs.clone().into_iter().last() {
+            | Some(TokenTree::Punct(ref punct))
+                if punct.as_char() == '!'
+            => {},
+
+            | _ => {
+                panic!("Expected a parameter of the form `macro_name !`");
+            },
+        }
+    }
+    let mut ret = attrs;
+    ret.extend(::std::iter::once(
+        TokenTree::Group(Group::new(
+            Delimiter::Brace,
+            // FIXME: directly using `input` makes the token stream be seen
+            // as a single token tree by the declarative macro !??
+            input.into_iter().collect(),
+        ))
+    ));
+    #[cfg(feature = "verbose-expansions")]
+    eprintln!("{}", ret);
+    ret
+}
+
+/// Applies the given `macro_rules!` macro to the decorated item.
+///
+/// This, as with any `#[derive(...)]`, **does not consume** the item it
+/// decorates: instead, it only generates code on top of it.
+///
+/// # Example
+///
+/// Implementing `Into<Int>` for a given `#[repr(Int)]` `enum`:
+///
+/// ```rust
+/// # macro_rules! ignore {($($tt:tt)*) => () }
+/// # ignore! {
+/// #[macro_use]
+/// extern crate macro_rules_attribute;
+/// # }
+///
+/// macro_rules! ToInteger {(
+///     #[repr($Int:ident)]
+///     $(#[$enum_meta:meta])*
+///     $pub:vis
+///     enum $Enum:ident {
+///         $(
+///             $Variant:ident $(= $value:expr)?
+///         ),* $(,)?
+///     }
+/// ) => (
+///     impl ::core::convert::From<$Enum> for $Int {
+///         #[inline]
+///         fn from (x: $Enum)
+///           -> Self
+///         {
+///             x as _
+///         }
+///     }
+/// )}
+///
+/// # use ::macro_rules_attribute_proc_macro::macro_rules_derive;
+/// #[macro_rules_derive(ToInteger!)]
+/// #[repr(u32)]
+/// enum Bool {
+///     False,
+///     True,
+/// }
+///
+/// fn main ()
+/// {
+///     assert_eq!(u32::from(Bool::False), 0);
+///     assert_eq!(u32::from(Bool::True), 1);
+///     // assert_eq!(u8::from(Bool::False), 0);
+///     // ^ error[E0277]: the trait bound `u8: std::convert::From<Bool>` is not satisfied
+/// }
+/// ```
+#[proc_macro_attribute] pub
+fn macro_rules_derive (
+    attrs: TokenStream,
+    input: TokenStream,
+) -> TokenStream
+{
+    let mut ret = input.clone();
+    ret.extend(macro_rules_attribute(attrs, input));
+    ret
+}
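
A quick round-trip sketch of the behavior the macro preserves from the
hand-written `Deserialize` impls it replaces (a hypothetical check, not part
of the patch; it assumes `serde_json` is available, as in the existing
doctests):

```rust
use tokenizers::pre_tokenizers::digits::Digits;

fn main() {
    // Serialization emits the `type` tag, exactly as `#[serde(tag = "type")]` did.
    let json = serde_json::to_string(&Digits::new(true)).unwrap();
    assert_eq!(json, r#"{"type":"Digits","individual_digits":true}"#);

    // Deserialization round-trips through the generated helper types.
    let back: Digits = serde_json::from_str(&json).unwrap();
    assert_eq!(back, Digits::new(true));

    // A payload without the `type` field is rejected.
    assert!(serde_json::from_str::<Digits>(r#"{"individual_digits":true}"#).is_err());
}
```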