mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Updating error messages. (#1599)
This commit is contained in:
@ -10,7 +10,7 @@ pub mod wordpiece;
|
|||||||
pub use super::pre_tokenizers::byte_level;
|
pub use super::pre_tokenizers::byte_level;
|
||||||
pub use super::pre_tokenizers::metaspace;
|
pub use super::pre_tokenizers::metaspace;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Deserializer, Serialize};
|
||||||
|
|
||||||
use crate::decoders::bpe::BPEDecoder;
|
use crate::decoders::bpe::BPEDecoder;
|
||||||
use crate::decoders::byte_fallback::ByteFallback;
|
use crate::decoders::byte_fallback::ByteFallback;
|
||||||
@ -24,7 +24,7 @@ use crate::pre_tokenizers::byte_level::ByteLevel;
|
|||||||
use crate::pre_tokenizers::metaspace::Metaspace;
|
use crate::pre_tokenizers::metaspace::Metaspace;
|
||||||
use crate::{Decoder, Result};
|
use crate::{Decoder, Result};
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
#[derive(Serialize, Clone, Debug)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum DecoderWrapper {
|
pub enum DecoderWrapper {
|
||||||
BPE(BPEDecoder),
|
BPE(BPEDecoder),
|
||||||
@ -39,6 +39,116 @@ pub enum DecoderWrapper {
|
|||||||
ByteFallback(ByteFallback),
|
ByteFallback(ByteFallback),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for DecoderWrapper {
|
||||||
|
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct Tagged {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
variant: EnumType,
|
||||||
|
#[serde(flatten)]
|
||||||
|
rest: serde_json::Value,
|
||||||
|
}
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub enum EnumType {
|
||||||
|
BPEDecoder,
|
||||||
|
ByteLevel,
|
||||||
|
WordPiece,
|
||||||
|
Metaspace,
|
||||||
|
CTC,
|
||||||
|
Sequence,
|
||||||
|
Replace,
|
||||||
|
Fuse,
|
||||||
|
Strip,
|
||||||
|
ByteFallback,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum DecoderHelper {
|
||||||
|
Tagged(Tagged),
|
||||||
|
Legacy(serde_json::Value),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum DecoderUntagged {
|
||||||
|
BPE(BPEDecoder),
|
||||||
|
ByteLevel(ByteLevel),
|
||||||
|
WordPiece(WordPiece),
|
||||||
|
Metaspace(Metaspace),
|
||||||
|
CTC(CTC),
|
||||||
|
Sequence(Sequence),
|
||||||
|
Replace(Replace),
|
||||||
|
Fuse(Fuse),
|
||||||
|
Strip(Strip),
|
||||||
|
ByteFallback(ByteFallback),
|
||||||
|
}
|
||||||
|
|
||||||
|
let helper = DecoderHelper::deserialize(deserializer).expect("Helper");
|
||||||
|
Ok(match helper {
|
||||||
|
DecoderHelper::Tagged(model) => {
|
||||||
|
let mut values: serde_json::Map<String, serde_json::Value> =
|
||||||
|
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?;
|
||||||
|
values.insert(
|
||||||
|
"type".to_string(),
|
||||||
|
serde_json::to_value(&model.variant).map_err(serde::de::Error::custom)?,
|
||||||
|
);
|
||||||
|
let values = serde_json::Value::Object(values);
|
||||||
|
match model.variant {
|
||||||
|
EnumType::BPEDecoder => DecoderWrapper::BPE(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::ByteLevel => DecoderWrapper::ByteLevel(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::WordPiece => DecoderWrapper::WordPiece(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Metaspace => DecoderWrapper::Metaspace(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::CTC => DecoderWrapper::CTC(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Sequence => DecoderWrapper::Sequence(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Replace => DecoderWrapper::Replace(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Fuse => DecoderWrapper::Fuse(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Strip => DecoderWrapper::Strip(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::ByteFallback => DecoderWrapper::ByteFallback(
|
||||||
|
serde_json::from_value(values).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DecoderHelper::Legacy(value) => {
|
||||||
|
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
|
||||||
|
match untagged {
|
||||||
|
DecoderUntagged::BPE(dec) => DecoderWrapper::BPE(dec),
|
||||||
|
DecoderUntagged::ByteLevel(dec) => DecoderWrapper::ByteLevel(dec),
|
||||||
|
DecoderUntagged::WordPiece(dec) => DecoderWrapper::WordPiece(dec),
|
||||||
|
DecoderUntagged::Metaspace(dec) => DecoderWrapper::Metaspace(dec),
|
||||||
|
DecoderUntagged::CTC(dec) => DecoderWrapper::CTC(dec),
|
||||||
|
DecoderUntagged::Sequence(dec) => DecoderWrapper::Sequence(dec),
|
||||||
|
DecoderUntagged::Replace(dec) => DecoderWrapper::Replace(dec),
|
||||||
|
DecoderUntagged::Fuse(dec) => DecoderWrapper::Fuse(dec),
|
||||||
|
DecoderUntagged::Strip(dec) => DecoderWrapper::Strip(dec),
|
||||||
|
DecoderUntagged::ByteFallback(dec) => DecoderWrapper::ByteFallback(dec),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Decoder for DecoderWrapper {
|
impl Decoder for DecoderWrapper {
|
||||||
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
|
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
|
||||||
match self {
|
match self {
|
||||||
@ -98,7 +208,7 @@ mod tests {
|
|||||||
match parse {
|
match parse {
|
||||||
Err(err) => assert_eq!(
|
Err(err) => assert_eq!(
|
||||||
format!("{err}"),
|
format!("{err}"),
|
||||||
"data did not match any variant of untagged enum DecoderWrapper"
|
"data did not match any variant of untagged enum DecoderUntagged"
|
||||||
),
|
),
|
||||||
_ => panic!("Expected error"),
|
_ => panic!("Expected error"),
|
||||||
}
|
}
|
||||||
@ -108,7 +218,7 @@ mod tests {
|
|||||||
match parse {
|
match parse {
|
||||||
Err(err) => assert_eq!(
|
Err(err) => assert_eq!(
|
||||||
format!("{err}"),
|
format!("{err}"),
|
||||||
"data did not match any variant of untagged enum DecoderWrapper"
|
"data did not match any variant of untagged enum DecoderUntagged"
|
||||||
),
|
),
|
||||||
_ => panic!("Expected error"),
|
_ => panic!("Expected error"),
|
||||||
}
|
}
|
||||||
@ -116,10 +226,7 @@ mod tests {
|
|||||||
let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
|
let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
|
||||||
let parse = serde_json::from_str::<DecoderWrapper>(json);
|
let parse = serde_json::from_str::<DecoderWrapper>(json);
|
||||||
match parse {
|
match parse {
|
||||||
Err(err) => assert_eq!(
|
Err(err) => assert_eq!(format!("{err}"), "missing field `decoders`"),
|
||||||
format!("{err}"),
|
|
||||||
"data did not match any variant of untagged enum DecoderWrapper"
|
|
||||||
),
|
|
||||||
_ => panic!("Expected error"),
|
_ => panic!("Expected error"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user