mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Better serialization error (#1595)
* Updating the deserialization error for models. * Update tokenizers/src/models/mod.rs Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@ -8,7 +8,7 @@ pub mod wordpiece;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize, Serializer};
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
|
||||||
use crate::models::bpe::{BpeTrainer, BPE};
|
use crate::models::bpe::{BpeTrainer, BPE};
|
||||||
use crate::models::unigram::{Unigram, UnigramTrainer};
|
use crate::models::unigram::{Unigram, UnigramTrainer};
|
||||||
@ -57,7 +57,7 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone)]
|
#[derive(Serialize, Debug, PartialEq, Clone)]
|
||||||
#[serde(untagged)]
|
#[serde(untagged)]
|
||||||
pub enum ModelWrapper {
|
pub enum ModelWrapper {
|
||||||
BPE(BPE),
|
BPE(BPE),
|
||||||
@ -68,6 +68,73 @@ pub enum ModelWrapper {
|
|||||||
Unigram(Unigram),
|
Unigram(Unigram),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for ModelWrapper {
|
||||||
|
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct Tagged {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
variant: EnumType,
|
||||||
|
#[serde(flatten)]
|
||||||
|
rest: serde_json::Value,
|
||||||
|
}
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub enum EnumType {
|
||||||
|
BPE,
|
||||||
|
WordPiece,
|
||||||
|
WordLevel,
|
||||||
|
Unigram,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ModelHelper {
|
||||||
|
Tagged(Tagged),
|
||||||
|
Legacy(serde_json::Value),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ModelUntagged {
|
||||||
|
BPE(BPE),
|
||||||
|
// WordPiece must stay before WordLevel here for deserialization (for retrocompatibility
|
||||||
|
// with the versions not including the "type"), since WordLevel is a subset of WordPiece
|
||||||
|
WordPiece(WordPiece),
|
||||||
|
WordLevel(WordLevel),
|
||||||
|
Unigram(Unigram),
|
||||||
|
}
|
||||||
|
|
||||||
|
let helper = ModelHelper::deserialize(deserializer)?;
|
||||||
|
Ok(match helper {
|
||||||
|
ModelHelper::Tagged(model) => match model.variant {
|
||||||
|
EnumType::BPE => ModelWrapper::BPE(
|
||||||
|
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::WordPiece => ModelWrapper::WordPiece(
|
||||||
|
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::WordLevel => ModelWrapper::WordLevel(
|
||||||
|
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
EnumType::Unigram => ModelWrapper::Unigram(
|
||||||
|
serde_json::from_value(model.rest).map_err(serde::de::Error::custom)?,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
ModelHelper::Legacy(value) => {
|
||||||
|
let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
|
||||||
|
match untagged {
|
||||||
|
ModelUntagged::BPE(bpe) => ModelWrapper::BPE(bpe),
|
||||||
|
ModelUntagged::WordPiece(bpe) => ModelWrapper::WordPiece(bpe),
|
||||||
|
ModelUntagged::WordLevel(bpe) => ModelWrapper::WordLevel(bpe),
|
||||||
|
ModelUntagged::Unigram(bpe) => ModelWrapper::Unigram(bpe),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
|
impl_enum_from!(WordLevel, ModelWrapper, WordLevel);
|
||||||
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
|
impl_enum_from!(WordPiece, ModelWrapper, WordPiece);
|
||||||
impl_enum_from!(BPE, ModelWrapper, BPE);
|
impl_enum_from!(BPE, ModelWrapper, BPE);
|
||||||
@ -263,10 +330,7 @@ mod tests {
|
|||||||
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
|
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
|
||||||
serde_json::from_str(invalid);
|
serde_json::from_str(invalid);
|
||||||
match reconstructed {
|
match reconstructed {
|
||||||
Err(err) => assert_eq!(
|
Err(err) => assert_eq!(err.to_string(), "Merges text file invalid at line 1"),
|
||||||
err.to_string(),
|
|
||||||
"data did not match any variant of untagged enum ModelWrapper"
|
|
||||||
),
|
|
||||||
_ => panic!("Expected an error here"),
|
_ => panic!("Expected an error here"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user