Make impl_serde_type support unit structs also

This commit is contained in:
Mishig Davaadorj
2022-01-25 17:57:22 +01:00
parent 1adcb63478
commit 3784a04fd4
7 changed files with 87 additions and 58 deletions

View File

@ -1,4 +1,5 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use crate::utils::macro_rules_attribute;
use serde::{Deserialize, Serialize};
use unicode_normalization_alignments::char::is_combining_mark;
@ -43,8 +44,8 @@ impl Normalizer for Strip {
// It's different from unidecode as it does not attempt to modify
// non ascii languages.
#[derive(Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct StripAccents;
impl_serde_unit_struct!(StripAccentsVisitor, StripAccents);
impl Normalizer for StripAccents {
/// Strip the normalized string inplace

View File

@ -1,6 +1,8 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use crate::utils::macro_rules_attribute;
#[derive(Default, Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct NFD;
impl Normalizer for NFD {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
@ -10,6 +12,7 @@ impl Normalizer for NFD {
}
#[derive(Default, Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct NFKD;
impl Normalizer for NFKD {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
@ -19,6 +22,7 @@ impl Normalizer for NFKD {
}
#[derive(Default, Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct NFC;
impl Normalizer for NFC {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
@ -28,6 +32,7 @@ impl Normalizer for NFC {
}
#[derive(Default, Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct NFKC;
impl Normalizer for NFKC {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
@ -68,6 +73,7 @@ fn do_nmt(normalized: &mut NormalizedString) {
}
#[derive(Default, Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Nmt;
impl Normalizer for Nmt {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
@ -76,12 +82,6 @@ impl Normalizer for Nmt {
}
}
impl_serde_unit_struct!(NFCVisitor, NFC);
impl_serde_unit_struct!(NFCKVisitor, NFKC);
impl_serde_unit_struct!(NFKDVisitor, NFKD);
impl_serde_unit_struct!(NFDVisitor, NFD);
impl_serde_unit_struct!(NMTVisitor, Nmt);
#[cfg(test)]
mod tests {
use super::*;

View File

@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
use crate::normalizers::NormalizerWrapper;
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use crate::utils::macro_rules_attribute;
#[derive(Clone, Deserialize, Debug, Serialize)]
#[serde(tag = "type")]
@ -36,6 +37,7 @@ impl Normalizer for Sequence {
/// Lowercases the input
#[derive(Copy, Clone, Debug)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Lowercase;
impl Normalizer for Lowercase {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
@ -43,5 +45,3 @@ impl Normalizer for Lowercase {
Ok(())
}
}
impl_serde_unit_struct!(LowercaseVisitor, Lowercase);

View File

@ -1,4 +1,5 @@
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use crate::utils::macro_rules_attribute;
use unicode_categories::UnicodeCategories;
fn is_bert_punc(x: char) -> bool {
@ -6,8 +7,8 @@ fn is_bert_punc(x: char) -> bool {
}
#[derive(Copy, Clone, Debug, PartialEq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct BertPreTokenizer;
impl_serde_unit_struct!(BertVisitor, BertPreTokenizer);
impl PreTokenizer for BertPreTokenizer {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {

View File

@ -1,9 +1,10 @@
use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
use crate::utils::macro_rules_attribute;
#[derive(Clone, Debug, PartialEq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct UnicodeScripts;
impl_serde_unit_struct!(UnicodeScriptsVisitor, UnicodeScripts);
impl UnicodeScripts {
pub fn new() -> Self {

View File

@ -3,10 +3,11 @@ use regex::Regex;
use crate::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
};
use crate::utils::macro_rules_attribute;
#[derive(Clone, Debug, PartialEq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Whitespace;
impl_serde_unit_struct!(WhitespaceVisitor, Whitespace);
impl Default for Whitespace {
fn default() -> Self {
@ -28,8 +29,8 @@ impl PreTokenizer for Whitespace {
}
#[derive(Copy, Clone, Debug, PartialEq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct WhitespaceSplit;
impl_serde_unit_struct!(WhitespaceSplitVisitor, WhitespaceSplit);
impl PreTokenizer for WhitespaceSplit {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {

View File

@ -33,48 +33,6 @@ macro_rules! impl_enum_from (
}
);
macro_rules! impl_serde_unit_struct (
($visitor:ident, $self_ty:tt) => {
impl serde::Serialize for $self_ty {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> where
S: serde::ser::Serializer {
use serde::ser::SerializeStruct;
let self_ty_str = stringify!($self_ty);
let mut m = serializer.serialize_struct(self_ty_str,1)?;
m.serialize_field("type", self_ty_str)?;
m.end()
}
}
impl<'de> serde::Deserialize<'de> for $self_ty {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error> where
D: serde::de::Deserializer<'de> {
deserializer.deserialize_map($visitor)
}
}
struct $visitor;
impl<'de> serde::de::Visitor<'de> for $visitor {
type Value = $self_ty;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, stringify!($self_ty))
}
fn visit_map<A>(self, mut map: A) -> std::result::Result<Self::Value, A::Error> where
A: serde::de::MapAccess<'de>, {
let self_ty_str = stringify!($self_ty);
let maybe_type = map.next_entry::<String, String>()?;
let maybe_type_str = maybe_type.as_ref().map(|(k, v)| (k.as_str(), v.as_str()));
match maybe_type_str {
Some(("type", stringify!($self_ty))) => Ok($self_ty),
Some((_, ty)) => Err(serde::de::Error::custom(&format!("Expected {}, got {}", self_ty_str, ty))),
None => Err(serde::de::Error::custom(&format!("Expected type : {}", self_ty_str)))
}
}
}
}
);
/// Implement `serde::{Serialize, Serializer}` with `#[serde(tag = "type")]` attribute for a given struct.
/// Panic when a json string being deserilized misses field `type`.
///
@ -119,6 +77,37 @@ macro_rules! impl_serde_unit_struct (
/// let deserialized: Point1D = serde_json::from_str(serialized_s).unwrap();
/// }
/// ```
///
/// # Examples (unit structs)
///
/// ```
/// # #[macro_use] extern crate tokenizers;
/// use serde::{Serialize, Deserialize};
///
/// fn main() {
/// impl_serde_type!{
/// struct Unit;
/// }
///
/// let unit = Unit;
/// let serialized_s = r#"{"type":"Unit"}"#;
/// assert_eq!(serde_json::to_string(&unit).unwrap(), serialized_s);
/// }
/// ```
///
/// ```should_panic
/// # #[macro_use] extern crate tokenizers;
/// use serde::{Serialize, Deserialize};
///
/// fn main() {
/// impl_serde_type!{
/// struct Unit;
/// }
///
/// let serialized_s = r#"{"some_field":1}"#;
/// let deserialized: Unit = serde_json::from_str(serialized_s).unwrap();
/// }
/// ```
#[macro_export]
macro_rules! impl_serde_type{
(
@ -130,9 +119,7 @@ macro_rules! impl_serde_type{
),*$(,)+
}
) => {
use paste::paste;
paste!{
paste::paste!{
$(#[$meta])*
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", from = $struct_name "Deserilaizer")]
@ -176,6 +163,44 @@ macro_rules! impl_serde_type{
}
}
}
};
(
$(#[$meta:meta])*
$vis:vis struct $struct_name:ident;
) => {
paste::paste!{
$(#[$meta])*
$vis struct $struct_name;
impl serde::Serialize for $struct_name {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> where
S: serde::ser::Serializer {
let helper = [<$struct_name Helper>]{r#type: [<$struct_name Type>]::$struct_name};
helper.serialize(serializer)
}
}
impl<'de> serde::Deserialize<'de> for $struct_name {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let _helper = [<$struct_name Helper>]::deserialize(deserializer)?;
Ok($struct_name)
}
}
#[derive(serde::Serialize, serde::Deserialize)]
enum [<$struct_name Type>] {
$struct_name,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct [<$struct_name Helper>] {
#[allow(dead_code)]
r#type: [<$struct_name Type>],
}
}
}
}