diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index 2a2a0772..c7622811 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -1,68 +1,29 @@ -use std::fmt; - use regex::Regex; -use serde::{Deserialize, Deserializer, Serialize}; use crate::tokenizer::{ pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior, }; -use serde::de::{Error, Visitor}; -#[derive(Clone, Debug, Serialize)] -#[serde(tag = "type")] -pub struct Whitespace { - #[serde(default = "default_regex", skip)] - re: Regex, -} - -fn default_regex() -> Regex { - Regex::new(r"\w+|[^\w\s]+").unwrap() -} +#[derive(Clone, Debug)] +pub struct Whitespace; +impl_serde_unit_struct!(WhitespaceVisitor, Whitespace); impl Default for Whitespace { fn default() -> Self { - Self { - re: default_regex(), - } + Self } } impl PreTokenizer for Whitespace { fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { - pretokenized.split(|_, normalized| { - normalized.split(Invert(&self.re), SplitDelimiterBehavior::Removed) - }) - } -} - -// manually implement deserialize because Whitespace is not a unit-struct but is -// serialized like one. -impl<'de> Deserialize<'de> for Whitespace { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_map(WhitespaceVisitor) - } -} -struct WhitespaceVisitor; -impl<'de> Visitor<'de> for WhitespaceVisitor { - type Value = Whitespace; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - write!(formatter, "Whitespace") - } - - fn visit_map(self, mut map: A) -> std::result::Result - where - A: serde::de::MapAccess<'de>, - { - let maybe_type = map.next_entry::()?; - let maybe_type_str = maybe_type.as_ref().map(|(k, v)| (k.as_str(), v.as_str())); - match maybe_type_str { - Some(("type", "Whitespace")) => Ok(Whitespace::default()), - Some((_, ty)) => Err(Error::custom(&format!("Expected Whitespace, got {}", ty))), - None => Err(Error::custom("Expected type: Whitespace")), + lazy_static! { + static ref RE: Regex = Regex::new(r"\w+|[^\w\s]+").unwrap(); } + let re_ref: &Regex = &RE; + + pretokenized.split(|_, normalized| { + normalized.split(Invert(re_ref), SplitDelimiterBehavior::Removed) + }) } }