mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Better serialization error (#1595)
* Updating the deserialization error for models. * Update tokenizers/src/models/mod.rs Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@ -8,7 +8,7 @@ pub mod wordpiece;
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
use crate::models::bpe::{BpeTrainer, BPE};
|
||||
use crate::models::unigram::{Unigram, UnigramTrainer};
|
||||
@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
|
||||
#[derive(Serialize, Debug, PartialEq, Clone)]
|
||||
#[serde(untagged)]
|
||||
pub enum ModelWrapper {
|
||||
BPE(BPE),
|
||||
@ -68,6 +68,73 @@ pub enum ModelWrapper {
|
||||
Unigram(Unigram),
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for ModelWrapper {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
pub struct Tagged {
|
||||
#[serde(rename = "type")]
|
||||
variant: EnumType,
|
||||
#[serde(flatten)]
|
||||
rest: serde_json::Value,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
pub enum EnumType {
|
||||
BPE,
|
||||
WordPiece,
|
||||
WordLevel,
|
||||
Unigram,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ModelHelper {
|
||||
Tagged(Tagged),
|
||||
Legacy(serde_json::Value),
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ModelUntagged {
|
||||
BPE(BPE),
|
||||
// WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
|
||||
// with the versions not including the "type"), since WordLevel is a subset of WordPiece
|
||||
WordPiece(WordPiece),
|
||||
WordLevel(WordLevel),
|
||||
Unigram(Unigram),
|
||||
}
|
||||
|
||||
let helper = ModelHelper::deserialize(deserializer)?;
|
||||
Ok(match helper {
|
||||
ModelHelper::Tagged(model) => match model.variant {
|
||||
EnumType::BPE => ModelWrapper::BPE(
|
||||
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::WordPiece => ModelWrapper::WordPiece(
|
||||
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::WordLevel => ModelWrapper::WordLevel(
|
||||
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
EnumType::Unigram => ModelWrapper::Unigram(
|
||||
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||
),
|
||||
},
|
||||
ModelHelper::Legacy(value) => {
|
||||
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
|
||||
match untagged {
|
||||
ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
|
||||
ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
|
||||
ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
|
||||
ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
|
||||
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
|
||||
impl_enum_from!(BPE, ModelWrapper, BPE);
|
||||
@ -263,10 +330,7 @@ mod tests {
|
||||
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
|
||||
serde_json::from_str(invalid);
|
||||
match reconstructed {
|
||||
Err(err) => assert_eq!(
|
||||
err.to_string(),
|
||||
"data did not match any variant of untagged enum ModelWrapper"
|
||||
),
|
||||
Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
|
||||
_ => panic!("Expected an error here"),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user