Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)

add deserialize for pre tokenizers (#1603)

* add deserialize
* copy from the decoder
* fmt
* clippy
* fix rust tests
* fmt
* don't change the test
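The commit replaces the derived `Deserialize` on `PreTokenizerWrapper` with an explicit implementation, copied from the decoders module: it first tries the tagged `{"type": ...}` JSON shape and only then falls back to the legacy untagged shape. A minimal sketch of the resulting behavior — not part of the diff; it assumes the crate's public module paths (`tokenizers::pre_tokenizers::...`) and `serde_json` as a dependency, and mirrors the assertions in the tests at the bottom of the diff:

```rust
use tokenizers::pre_tokenizers::metaspace::Metaspace;
use tokenizers::pre_tokenizers::PreTokenizerWrapper;

fn main() {
    // Tagged shape: dispatched on the "type" field.
    let ok: PreTokenizerWrapper =
        serde_json::from_str(r#"{"type":"Metaspace","replacement":"▁"}"#).unwrap();
    assert_eq!(ok, PreTokenizerWrapper::Metaspace(Metaspace::default()));

    // A tagged payload missing a required field now fails with a precise
    // serde error instead of the generic "did not match any variant" one.
    let err = serde_json::from_str::<PreTokenizerWrapper>(
        r#"{"type":"Metaspace","add_prefix_space":true}"#,
    )
    .unwrap_err();
    assert_eq!(err.to_string(), "missing field `replacement`");
}
```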
```diff
@@ -9,7 +9,7 @@ pub mod split;
 pub mod unicode_scripts;
 pub mod whitespace;
 
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::pre_tokenizers::bert::BertPreTokenizer;
 use crate::pre_tokenizers::byte_level::ByteLevel;
@@ -23,7 +23,7 @@ use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
 use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
 use crate::{PreTokenizedString, PreTokenizer};
 
-#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
+#[derive(Serialize, Clone, Debug, PartialEq)]
 #[serde(untagged)]
 pub enum PreTokenizerWrapper {
     BertPreTokenizer(BertPreTokenizer),
```
```diff
@@ -57,6 +57,142 @@ impl PreTokenizer for PreTokenizerWrapper {
     }
 }
 
+impl<'de> Deserialize<'de> for PreTokenizerWrapper {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        pub struct Tagged {
+            #[serde(rename = "type")]
+            variant: EnumType,
+            #[serde(flatten)]
+            rest: serde_json::Value,
+        }
+        #[derive(Deserialize, Serialize)]
+        pub enum EnumType {
+            BertPreTokenizer,
+            ByteLevel,
+            Delimiter,
+            Metaspace,
+            Whitespace,
+            Sequence,
+            Split,
+            Punctuation,
+            WhitespaceSplit,
+            Digits,
+            UnicodeScripts,
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum PreTokenizerHelper {
+            Tagged(Tagged),
+            Legacy(serde_json::Value),
+        }
+
+        #[derive(Deserialize)]
+        #[serde(untagged)]
+        pub enum PreTokenizerUntagged {
+            BertPreTokenizer(BertPreTokenizer),
+            ByteLevel(ByteLevel),
+            Delimiter(CharDelimiterSplit),
+            Metaspace(Metaspace),
+            Whitespace(Whitespace),
+            Sequence(Sequence),
+            Split(Split),
+            Punctuation(Punctuation),
+            WhitespaceSplit(WhitespaceSplit),
+            Digits(Digits),
+            UnicodeScripts(UnicodeScripts),
+        }
+
+        let helper = PreTokenizerHelper::deserialize(deserializer)?;
+
+        Ok(match helper {
+            PreTokenizerHelper::Tagged(pretok) => {
+                let mut values: serde_json::Map<String, serde_json::Value> =
+                    serde_json::from_value(pretok.rest).map_err(serde::de::Error::custom)?;
+                values.insert(
+                    "type".to_string(),
+                    serde_json::to_value(&pretok.variant).map_err(serde::de::Error::custom)?,
+                );
+                let values = serde_json::Value::Object(values);
+                match pretok.variant {
+                    EnumType::BertPreTokenizer => PreTokenizerWrapper::BertPreTokenizer(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::ByteLevel => PreTokenizerWrapper::ByteLevel(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Delimiter => PreTokenizerWrapper::Delimiter(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Metaspace => PreTokenizerWrapper::Metaspace(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Whitespace => PreTokenizerWrapper::Whitespace(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Sequence => PreTokenizerWrapper::Sequence(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Split => PreTokenizerWrapper::Split(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Punctuation => PreTokenizerWrapper::Punctuation(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::WhitespaceSplit => PreTokenizerWrapper::WhitespaceSplit(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::Digits => PreTokenizerWrapper::Digits(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                    EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
+                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
+                    ),
+                }
+            }
+
+            PreTokenizerHelper::Legacy(value) => {
+                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
+                match untagged {
+                    PreTokenizerUntagged::BertPreTokenizer(bert) => {
+                        PreTokenizerWrapper::BertPreTokenizer(bert)
+                    }
+                    PreTokenizerUntagged::ByteLevel(byte_level) => {
+                        PreTokenizerWrapper::ByteLevel(byte_level)
+                    }
+                    PreTokenizerUntagged::Delimiter(delimiter) => {
+                        PreTokenizerWrapper::Delimiter(delimiter)
+                    }
+                    PreTokenizerUntagged::Metaspace(metaspace) => {
+                        PreTokenizerWrapper::Metaspace(metaspace)
+                    }
+                    PreTokenizerUntagged::Whitespace(whitespace) => {
+                        PreTokenizerWrapper::Whitespace(whitespace)
+                    }
+                    PreTokenizerUntagged::Sequence(sequence) => {
+                        PreTokenizerWrapper::Sequence(sequence)
+                    }
+                    PreTokenizerUntagged::Split(split) => PreTokenizerWrapper::Split(split),
+                    PreTokenizerUntagged::Punctuation(punctuation) => {
+                        PreTokenizerWrapper::Punctuation(punctuation)
+                    }
+                    PreTokenizerUntagged::WhitespaceSplit(whitespace_split) => {
+                        PreTokenizerWrapper::WhitespaceSplit(whitespace_split)
+                    }
+                    PreTokenizerUntagged::Digits(digits) => PreTokenizerWrapper::Digits(digits),
+                    PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
+                        PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
+                    }
+                }
+            }
+        })
+    }
+}
+
 impl_enum_from!(BertPreTokenizer, PreTokenizerWrapper, BertPreTokenizer);
 impl_enum_from!(ByteLevel, PreTokenizerWrapper, ByteLevel);
 impl_enum_from!(CharDelimiterSplit, PreTokenizerWrapper, Delimiter);
```
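Note the `Tagged` branch re-inserts the `type` key before calling `serde_json::from_value`, presumably because the individual pre-tokenizers' own serde impls expect that tag to be present. Stripped to a single toy variant, the tagged-first, legacy-fallback pattern looks like this — a self-contained sketch, all names hypothetical; the toy struct uses a plain derive, so no `type` key needs to be re-inserted:

```rust
use serde::{Deserialize, Deserializer, Serialize};

// Toy stand-in for a real pre-tokenizer; it neither writes nor expects a tag.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Delimiter {
    pub delimiter: char,
}

#[derive(Serialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum Wrapper {
    Delimiter(Delimiter),
}

impl<'de> Deserialize<'de> for Wrapper {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Tagged shape: {"type": "...", ...remaining fields...}.
        #[derive(Deserialize)]
        struct Tagged {
            #[serde(rename = "type")]
            variant: String,
            #[serde(flatten)]
            rest: serde_json::Value,
        }
        // Try Tagged first; anything without a "type" key lands in Legacy.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum Helper {
            Tagged(Tagged),
            Legacy(serde_json::Value),
        }
        Ok(match Helper::deserialize(deserializer)? {
            Helper::Tagged(t) => match t.variant.as_str() {
                "Delimiter" => Wrapper::Delimiter(
                    serde_json::from_value(t.rest).map_err(serde::de::Error::custom)?,
                ),
                other => {
                    return Err(serde::de::Error::custom(format!("unknown type `{other}`")))
                }
            },
            Helper::Legacy(value) => Wrapper::Delimiter(
                serde_json::from_value(value).map_err(serde::de::Error::custom)?,
            ),
        })
    }
}

fn main() {
    let tagged: Wrapper =
        serde_json::from_str(r#"{"type":"Delimiter","delimiter":"-"}"#).unwrap();
    let legacy: Wrapper = serde_json::from_str(r#"{"delimiter":"-"}"#).unwrap();
    assert_eq!(tagged, legacy);
}
```

In the real impl, dispatch goes through the `EnumType` enum rather than string matching, which keeps the set of accepted tags in sync with the wrapper's variants. The test changes below reflect the new error surface: unmatched legacy input now names the `PreTokenizerUntagged` helper, and a tagged payload with a missing field gets a precise field error.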
```diff
@@ -152,25 +288,22 @@ mod tests {
         match reconstructed {
             Err(err) => assert_eq!(
                 err.to_string(),
-                "data did not match any variant of untagged enum PreTokenizerWrapper"
+                "data did not match any variant of untagged enum PreTokenizerUntagged"
             ),
             _ => panic!("Expected an error here"),
         }
 
         let json = r#"{"type":"Metaspace", "replacement":"▁" }"#;
-        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
+        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json).unwrap();
         assert_eq!(
-            reconstructed.unwrap(),
+            reconstructed,
             PreTokenizerWrapper::Metaspace(Metaspace::default())
         );
 
         let json = r#"{"type":"Metaspace", "add_prefix_space":true }"#;
         let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
         match reconstructed {
-            Err(err) => assert_eq!(
-                err.to_string(),
-                "data did not match any variant of untagged enum PreTokenizerWrapper"
-            ),
+            Err(err) => assert_eq!(err.to_string(), "missing field `replacement`"),
             _ => panic!("Expected an error here"),
         }
         let json = r#"{"behavior":"default_split"}"#;
@@ -178,7 +311,7 @@ mod tests {
         match reconstructed {
             Err(err) => assert_eq!(
                 err.to_string(),
-                "data did not match any variant of untagged enum PreTokenizerWrapper"
+                "data did not match any variant of untagged enum PreTokenizerUntagged"
             ),
             _ => panic!("Expected an error here"),
         }
```