mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Tests + Deserialization improvement for normalizers. (#1604)
This commit is contained in:
@ -14,12 +14,12 @@ pub use crate::normalizers::replace::Replace;
|
||||
pub use crate::normalizers::strip::{Strip, StripAccents};
|
||||
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
|
||||
pub use crate::normalizers::utils::{Lowercase, Sequence};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::{NormalizedString, Normalizer};
|
||||
|
||||
/// Wrapper for known Normalizers.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum NormalizerWrapper {
|
||||
BertNormalizer(BertNormalizer),
|
||||
@ -38,6 +38,149 @@ pub enum NormalizerWrapper {
|
||||
ByteLevel(ByteLevel),
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for NormalizerWrapper {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
pub struct Tagged {
|
||||
#[serde(rename = "type")]
|
||||
variant: EnumType,
|
||||
#[serde(flatten)]
|
||||
rest: serde_json::Value,
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum EnumType {
|
||||
Bert,
|
||||
Strip,
|
||||
StripAccents,
|
||||
NFC,
|
||||
NFD,
|
||||
NFKC,
|
||||
NFKD,
|
||||
Sequence,
|
||||
Lowercase,
|
||||
Nmt,
|
||||
Precompiled,
|
||||
Replace,
|
||||
Prepend,
|
||||
ByteLevel,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum NormalizerHelper {
|
||||
Tagged(Tagged),
|
||||
Legacy(serde_json::Value),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum NormalizerUntagged {
|
||||
BertNormalizer(BertNormalizer),
|
||||
StripNormalizer(Strip),
|
||||
StripAccents(StripAccents),
|
||||
NFC(NFC),
|
||||
NFD(NFD),
|
||||
NFKC(NFKC),
|
||||
NFKD(NFKD),
|
||||
Sequence(Sequence),
|
||||
Lowercase(Lowercase),
|
||||
Nmt(Nmt),
|
||||
Precompiled(Precompiled),
|
||||
Replace(Replace),
|
||||
Prepend(Prepend),
|
||||
ByteLevel(ByteLevel),
|
||||
}
|
||||
|
||||
let helper = NormalizerHelper::deserialize(deserializer)?;
|
||||
Ok(match helper {
|
||||
NormalizerHelper::Tagged(model) => {
|
||||
let mut values: serde_json::Map<String, serde_json::Value> =
|
||||
serde_json::from_value(model.rest).expect("Parsed values");
|
||||
values.insert(
|
||||
"type".to_string(),
|
||||
serde_json::to_value(&model.variant).expect("Reinsert"),
|
||||
);
|
||||
let values = serde_json::Value::Object(values);
|
||||
match model.variant {
|
||||
EnumType::Bert => NormalizerWrapper::BertNormalizer(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Strip => NormalizerWrapper::StripNormalizer(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::StripAccents => NormalizerWrapper::StripAccents(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::NFC => NormalizerWrapper::NFC(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::NFD => NormalizerWrapper::NFD(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::NFKC => NormalizerWrapper::NFKC(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::NFKD => NormalizerWrapper::NFKD(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Sequence => NormalizerWrapper::Sequence(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Lowercase => NormalizerWrapper::Lowercase(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Nmt => NormalizerWrapper::Nmt(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Precompiled => NormalizerWrapper::Precompiled(
|
||||
serde_json::from_str(
|
||||
&serde_json::to_string(&values).expect("Can reserialize precompiled"),
|
||||
)
|
||||
// .map_err(serde::de::Error::custom)
|
||||
.expect("Precompiled"),
|
||||
),
|
||||
EnumType::Replace => NormalizerWrapper::Replace(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Prepend => NormalizerWrapper::Prepend(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
NormalizerHelper::Legacy(value) => {
|
||||
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
|
||||
match untagged {
|
||||
NormalizerUntagged::BertNormalizer(bpe) => {
|
||||
NormalizerWrapper::BertNormalizer(bpe)
|
||||
}
|
||||
NormalizerUntagged::StripNormalizer(bpe) => {
|
||||
NormalizerWrapper::StripNormalizer(bpe)
|
||||
}
|
||||
NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe),
|
||||
NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe),
|
||||
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
|
||||
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
|
||||
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
|
||||
NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe),
|
||||
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
|
||||
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
|
||||
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
|
||||
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
|
||||
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
|
||||
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Normalizer for NormalizerWrapper {
|
||||
fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
|
||||
match self {
|
||||
@ -91,7 +234,7 @@ mod tests {
|
||||
match reconstructed {
|
||||
Err(err) => assert_eq!(
|
||||
err.to_string(),
|
||||
"data did not match any variant of untagged enum NormalizerWrapper"
|
||||
"data did not match any variant of untagged enum NormalizerUntagged"
|
||||
),
|
||||
_ => panic!("Expected an error here"),
|
||||
}
|
||||
@ -103,4 +246,36 @@ mod tests {
|
||||
NormalizerWrapper::Prepend(_)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalizer_serialization() {
|
||||
let json = r#"{"type":"Sequence","normalizers":[]}"#;
|
||||
assert!(serde_json::from_str::<NormalizerWrapper>(json).is_ok());
|
||||
let json = r#"{"type":"Sequence","normalizers":[{}]}"#;
|
||||
let parse = serde_json::from_str::<NormalizerWrapper>(json);
|
||||
match parse {
|
||||
Err(err) => assert_eq!(
|
||||
format!("{err}"),
|
||||
"data did not match any variant of untagged enum NormalizerUntagged"
|
||||
),
|
||||
_ => panic!("Expected error"),
|
||||
}
|
||||
|
||||
let json = r#"{"replacement":"▁","prepend_scheme":"always"}"#;
|
||||
let parse = serde_json::from_str::<NormalizerWrapper>(json);
|
||||
match parse {
|
||||
Err(err) => assert_eq!(
|
||||
format!("{err}"),
|
||||
"data did not match any variant of untagged enum NormalizerUntagged"
|
||||
),
|
||||
_ => panic!("Expected error"),
|
||||
}
|
||||
|
||||
let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
|
||||
let parse = serde_json::from_str::<NormalizerWrapper>(json);
|
||||
match parse {
|
||||
Err(err) => assert_eq!(format!("{err}"), "missing field `normalizers`"),
|
||||
_ => panic!("Expected error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user