mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 23:09:34 +00:00
Tests + Deserialization improvement for normalizers. (#1604)
This commit is contained in:
@ -14,12 +14,12 @@ pub use crate::normalizers::replace::Replace;
|
|||||||
pub use crate::normalizers::strip::{Strip, StripAccents};
|
pub use crate::normalizers::strip::{Strip, StripAccents};
|
||||||
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
|
pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
|
||||||
pub use crate::normalizers::utils::{Lowercase, Sequence};
|
pub use crate::normalizers::utils::{Lowercase, Sequence};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
|
|
||||||
use crate::{NormalizedString, Normalizer};
|
use crate::{NormalizedString, Normalizer};
|
||||||
|
|
||||||
/// Wrapper for known Normalizers.
|
/// Wrapper for known Normalizers.
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Serialize)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum NormalizerWrapper {
|
pub enum NormalizerWrapper {
|
||||||
BertNormalizer(BertNormalizer),
|
BertNormalizer(BertNormalizer),
|
||||||
@ -38,6 +38,149 @@ pub enum NormalizerWrapper {
|
|||||||
ByteLevel(ByteLevel),
|
ByteLevel(ByteLevel),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for NormalizerWrapper {
|
||||||
|
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct Tagged {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
variant: EnumType,
|
||||||
|
#[serde(flatten)]
|
||||||
|
rest: serde_json::Value,
|
||||||
|
}
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub enum EnumType {
|
||||||
|
Bert,
|
||||||
|
Strip,
|
||||||
|
StripAccents,
|
||||||
|
NFC,
|
||||||
|
NFD,
|
||||||
|
NFKC,
|
||||||
|
NFKD,
|
||||||
|
Sequence,
|
||||||
|
Lowercase,
|
||||||
|
Nmt,
|
||||||
|
Precompiled,
|
||||||
|
Replace,
|
||||||
|
Prepend,
|
||||||
|
ByteLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum NormalizerHelper {
|
||||||
|
Tagged(Tagged),
|
||||||
|
Legacy(serde_json::Value),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum NormalizerUntagged {
|
||||||
|
BertNormalizer(BertNormalizer),
|
||||||
|
StripNormalizer(Strip),
|
||||||
|
StripAccents(StripAccents),
|
||||||
|
NFC(NFC),
|
||||||
|
NFD(NFD),
|
||||||
|
NFKC(NFKC),
|
||||||
|
NFKD(NFKD),
|
||||||
|
Sequence(Sequence),
|
||||||
|
Lowercase(Lowercase),
|
||||||
|
Nmt(Nmt),
|
||||||
|
Precompiled(Precompiled),
|
||||||
|
Replace(Replace),
|
||||||
|
Prepend(Prepend),
|
||||||
|
ByteLevel(ByteLevel),
|
||||||
|
}
|
||||||
|
|
||||||
|
let helper = NormalizerHelper::deserialize(deserializer)?;
|
||||||
|
Ok(match helper {
|
||||||
|
NormalizerHelper::Tagged(model) => {
|
||||||
|
let mut values: serde_json::Map<String, serde_json::Value> =
|
||||||
|
serde_json::from_value(model.rest).expect("Parsed values");
|
||||||
|
values.insert(
|
||||||
|
"type".to_string(),
|
||||||
|
serde_json::to_value(&model.variant).expect("Reinsert"),
|
||||||
|
);
|
||||||
|
let values = serde_json::Value::Object(values);
|
||||||
|
match model.variant {
|
||||||
|
EnumType::Bert => NormalizerWrapper::BertNormalizer(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Strip => NormalizerWrapper::StripNormalizer(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::StripAccents => NormalizerWrapper::StripAccents(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::NFC => NormalizerWrapper::NFC(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::NFD => NormalizerWrapper::NFD(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::NFKC => NormalizerWrapper::NFKC(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::NFKD => NormalizerWrapper::NFKD(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Sequence => NormalizerWrapper::Sequence(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Lowercase => NormalizerWrapper::Lowercase(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Nmt => NormalizerWrapper::Nmt(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Precompiled => NormalizerWrapper::Precompiled(
|
||||||
|
serde_json::from_str(
|
||||||
|
&serde_json::to_string(&values).expect("Can reserialize precompiled"),
|
||||||
|
)
|
||||||
|
// .map_err(serde::de::Error::custom)
|
||||||
|
.expect("Precompiled"),
|
||||||
|
),
|
||||||
|
EnumType::Replace => NormalizerWrapper::Replace(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Prepend => NormalizerWrapper::Prepend(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NormalizerHelper::Legacy(value) => {
|
||||||
|
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
|
||||||
|
match untagged {
|
||||||
|
NormalizerUntagged::BertNormalizer(bpe) => {
|
||||||
|
NormalizerWrapper::BertNormalizer(bpe)
|
||||||
|
}
|
||||||
|
NormalizerUntagged::StripNormalizer(bpe) => {
|
||||||
|
NormalizerWrapper::StripNormalizer(bpe)
|
||||||
|
}
|
||||||
|
NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe),
|
||||||
|
NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe),
|
||||||
|
NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
|
||||||
|
NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
|
||||||
|
NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
|
||||||
|
NormalizerUntagged::Sequence(bpe) => NormalizerWrapper::Sequence(bpe),
|
||||||
|
NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
|
||||||
|
NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
|
||||||
|
NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
|
||||||
|
NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
|
||||||
|
NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
|
||||||
|
NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Normalizer for NormalizerWrapper {
|
impl Normalizer for NormalizerWrapper {
|
||||||
fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
|
fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
|
||||||
match self {
|
match self {
|
||||||
@ -91,7 +234,7 @@ mod tests {
|
|||||||
match reconstructed {
|
match reconstructed {
|
||||||
Err(err) => assert_eq!(
|
Err(err) => assert_eq!(
|
||||||
err.to_string(),
|
err.to_string(),
|
||||||
"data did not match any variant of untagged enum NormalizerWrapper"
|
"data did not match any variant of untagged enum NormalizerUntagged"
|
||||||
),
|
),
|
||||||
_ => panic!("Expected an error here"),
|
_ => panic!("Expected an error here"),
|
||||||
}
|
}
|
||||||
@ -103,4 +246,36 @@ mod tests {
|
|||||||
NormalizerWrapper::Prepend(_)
|
NormalizerWrapper::Prepend(_)
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
fn normalizer_serialization() {
    /// Asserts that `input` fails to parse as a `NormalizerWrapper` and that the
    /// error renders exactly as `expected`.
    fn assert_rejected(input: &str, expected: &str) {
        match serde_json::from_str::<NormalizerWrapper>(input) {
            Err(err) => assert_eq!(err.to_string(), expected),
            Ok(_) => panic!("Expected error"),
        }
    }

    // A tagged, well-formed (empty) sequence is accepted.
    let valid = r#"{"type":"Sequence","normalizers":[]}"#;
    assert!(serde_json::from_str::<NormalizerWrapper>(valid).is_ok());

    // A sequence element without a tag goes through the untagged fallback,
    // which cannot match an empty object.
    assert_rejected(
        r#"{"type":"Sequence","normalizers":[{}]}"#,
        "data did not match any variant of untagged enum NormalizerUntagged",
    );

    // Untagged input whose fields match no known normalizer.
    assert_rejected(
        r#"{"replacement":"▁","prepend_scheme":"always"}"#,
        "data did not match any variant of untagged enum NormalizerUntagged",
    );

    // Tagged input: the error is specific to the selected normalizer.
    assert_rejected(
        r#"{"type":"Sequence","prepend_scheme":"always"}"#,
        "missing field `normalizers`",
    );
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user