Better serialization error (#1595)

* Updating the deserialization error for models.

* Update tokenizers/src/models/mod.rs

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2024-08-06 13:39:11 +02:00
committed by GitHub
parent 2d27761f60
commit fe41687ca8

View File

@ -8,7 +8,7 @@ pub mod wordpiece;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::models::bpe::{BpeTrainer, BPE};
use crate::models::unigram::{Unigram, UnigramTrainer};
@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
}
}
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
#[derive(Serialize, Debug, PartialEq, Clone)]
#[serde(untagged)]
pub enum ModelWrapper {
BPE(BPE),
@ -68,6 +68,73 @@ pub enum ModelWrapper {
Unigram(Unigram),
}
impl<'de> Deserialize<'de> for ModelWrapper {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
pub struct Tagged {
#[serde(rename = "type")]
variant: EnumType,
#[serde(flatten)]
rest: serde_json::Value,
}
#[derive(Deserialize)]
pub enum EnumType {
BPE,
WordPiece,
WordLevel,
Unigram,
}
#[derive(Deserialize)]
#[serde(untagged)]
pub enum ModelHelper {
Tagged(Tagged),
Legacy(serde_json::Value),
}
#[derive(Deserialize)]
#[serde(untagged)]
pub enum ModelUntagged {
BPE(BPE),
// WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
// with the versions not including the "type"), since WordLevel is a subset of WordPiece
WordPiece(WordPiece),
WordLevel(WordLevel),
Unigram(Unigram),
}
let helper = ModelHelper::deserialize(deserializer)?;
Ok(match helper {
ModelHelper::Tagged(model) => match model.variant {
EnumType::BPE => ModelWrapper::BPE(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
EnumType::WordPiece => ModelWrapper::WordPiece(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
EnumType::WordLevel => ModelWrapper::WordLevel(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
EnumType::Unigram => ModelWrapper::Unigram(
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
),
},
ModelHelper::Legacy(value) => {
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
match untagged {
ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
}
}
})
}
}
impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
impl_enum_from!(BPE, ModelWrapper, BPE);
@ -263,10 +330,7 @@ mod tests {
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
serde_json::from_str(invalid);
match reconstructed {
Err(err) => assert_eq!(
err.to_string(),
"data did not match any variant of untagged enum ModelWrapper"
),
Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
_ => panic!("Expected an error here"),
}
}