mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Updating error messages. (#1599)
This commit is contained in:
@ -10,7 +10,7 @@ pub mod wordpiece;
|
||||
pub use super::pre_tokenizers::byte_level;
|
||||
pub use super::pre_tokenizers::metaspace;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::decoders::bpe::BPEDecoder;
|
||||
use crate::decoders::byte_fallback::ByteFallback;
|
||||
@ -24,7 +24,7 @@ use crate::pre_tokenizers::byte_level::ByteLevel;
|
||||
use crate::pre_tokenizers::metaspace::Metaspace;
|
||||
use crate::{Decoder, Result};
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum DecoderWrapper {
|
||||
BPE(BPEDecoder),
|
||||
@ -39,6 +39,116 @@ pub enum DecoderWrapper {
|
||||
ByteFallback(ByteFallback),
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for DecoderWrapper {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
pub struct Tagged {
|
||||
#[serde(rename = "type")]
|
||||
variant: EnumType,
|
||||
#[serde(flatten)]
|
||||
rest: serde_json::Value,
|
||||
}
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum EnumType {
|
||||
BPEDecoder,
|
||||
ByteLevel,
|
||||
WordPiece,
|
||||
Metaspace,
|
||||
CTC,
|
||||
Sequence,
|
||||
Replace,
|
||||
Fuse,
|
||||
Strip,
|
||||
ByteFallback,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum DecoderHelper {
|
||||
Tagged(Tagged),
|
||||
Legacy(serde_json::Value),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum DecoderUntagged {
|
||||
BPE(BPEDecoder),
|
||||
ByteLevel(ByteLevel),
|
||||
WordPiece(WordPiece),
|
||||
Metaspace(Metaspace),
|
||||
CTC(CTC),
|
||||
Sequence(Sequence),
|
||||
Replace(Replace),
|
||||
Fuse(Fuse),
|
||||
Strip(Strip),
|
||||
ByteFallback(ByteFallback),
|
||||
}
|
||||
|
||||
let helper = DecoderHelper::deserialize(deserializer).expect("Helper");
|
||||
Ok(match helper {
|
||||
DecoderHelper::Tagged(model) => {
|
||||
let mut values: serde_json::Map<String, serde_json::Value> =
|
||||
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?;
|
||||
values.insert(
|
||||
"type".to_string(),
|
||||
serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?,
|
||||
);
|
||||
let values = serde_json::Value::Object(values);
|
||||
match model.variant {
|
||||
EnumType::BPEDecoder => DecoderWrapper::BPE(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::ByteLevel => DecoderWrapper::ByteLevel(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::WordPiece => DecoderWrapper::WordPiece(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Metaspace => DecoderWrapper::Metaspace(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::CTC => DecoderWrapper::CTC(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Sequence => DecoderWrapper::Sequence(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Replace => DecoderWrapper::Replace(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Fuse => DecoderWrapper::Fuse(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Strip => DecoderWrapper::Strip(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::ByteFallback => DecoderWrapper::ByteFallback(
|
||||
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
}
|
||||
}
|
||||
DecoderHelper::Legacy(value) => {
|
||||
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
|
||||
match untagged {
|
||||
DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec),
|
||||
DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec),
|
||||
DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec),
|
||||
DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec),
|
||||
DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec),
|
||||
DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec),
|
||||
DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec),
|
||||
DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec),
|
||||
DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec),
|
||||
DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for DecoderWrapper {
|
||||
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
|
||||
match self {
|
||||
@ -98,7 +208,7 @@ mod tests {
|
||||
match parse {
|
||||
Err(err) => assert_eq!(
|
||||
format!("{err}"),
|
||||
"data did not match any variant of untagged enum DecoderWrapper"
|
||||
"data did not match any variant of untagged enum DecoderUntagged"
|
||||
),
|
||||
_ => panic!("Expected error"),
|
||||
}
|
||||
@ -108,7 +218,7 @@ mod tests {
|
||||
match parse {
|
||||
Err(err) => assert_eq!(
|
||||
format!("{err}"),
|
||||
"data did not match any variant of untagged enum DecoderWrapper"
|
||||
"data did not match any variant of untagged enum DecoderUntagged"
|
||||
),
|
||||
_ => panic!("Expected error"),
|
||||
}
|
||||
@ -116,10 +226,7 @@ mod tests {
|
||||
let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
|
||||
let parse = serde_json::from_str::<DecoderWrapper>(json);
|
||||
match parse {
|
||||
Err(err) => assert_eq!(
|
||||
format!("{err}"),
|
||||
"data did not match any variant of untagged enum DecoderWrapper"
|
||||
),
|
||||
Err(err) => assert_eq!(format!("{err}"), "missing field `decoders`"),
|
||||
_ => panic!("Expected error"),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user