Adding some serialization testing around the wrapper.

This commit is contained in:
Nicolas Patry
2024-08-06 09:55:01 +02:00
parent 7b80359dd2
commit 388014fd6b
2 changed files with 56 additions and 5 deletions

View File

@ -151,18 +151,27 @@ mod test {
#[test] #[test]
fn test_serialization() { fn test_serialization() {
let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)] let vocab: Vocab = [
.iter() ("<unk>".into(), 0),
.cloned() ("a".into(), 1),
.collect(); ("b".into(), 2),
("ab".into(), 3),
]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default() let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, vec![]) .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
.unk_token("<unk>".to_string()) .unk_token("<unk>".to_string())
.ignore_merges(true) .ignore_merges(true)
.build() .build()
.unwrap(); .unwrap();
let data = serde_json::to_string(&bpe).unwrap(); let data = serde_json::to_string(&bpe).unwrap();
assert_eq!(
data,
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
);
let reconstructed = serde_json::from_str(&data).unwrap(); let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(bpe, reconstructed); assert_eq!(bpe, reconstructed);

View File

@ -204,6 +204,7 @@ impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::models::bpe::{BpeBuilder, Vocab};
#[test] #[test]
fn trainer_wrapper_train_model_wrapper() { fn trainer_wrapper_train_model_wrapper() {
@ -224,4 +225,45 @@ mod tests {
let serialized = serde_json::to_string(&ordered).unwrap(); let serialized = serde_json::to_string(&ordered).unwrap();
assert_eq!(serialized, "{\"Hi\":0,\"There\":2}"); assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
} }
/// Checks JSON (de)serialization of a BPE model wrapped in `ModelWrapper`.
///
/// Three things are pinned down:
/// 1. the exact JSON string produced for the wrapped model,
/// 2. that deserializing that string yields a value equal to the original,
/// 3. that a payload whose merge entry is `"a b c"` is rejected, and with
///    which error message.
#[test]
fn serialization() {
    let vocab: Vocab = [
        ("<unk>".into(), 0),
        ("a".into(), 1),
        ("b".into(), 2),
        ("ab".into(), 3),
    ]
    .iter()
    .cloned()
    .collect();
    // A single merge rule pairing "a" with "b"; "ab" is present in the vocab above.
    let merges = vec![("a".to_string(), "b".to_string())];
    let bpe = BpeBuilder::default()
        .vocab_and_merges(vocab, merges)
        .unk_token("<unk>".to_string())
        .ignore_merges(true)
        .build()
        .unwrap();
    let model = ModelWrapper::BPE(bpe);

    // Serialized form must match this string exactly (field order included).
    let expected = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
    let data = serde_json::to_string(&model).unwrap();
    assert_eq!(data, expected);

    // Round trip: the target type is spelled out rather than inferred from the assert.
    let reconstructed: ModelWrapper = serde_json::from_str(&data).unwrap();
    assert_eq!(model, reconstructed);

    // Same payload as above except for the merge entry, which is "a b c".
    let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
    // `invalid` is already a `&str`, so it is passed directly (no extra borrow).
    let err = match serde_json::from_str::<ModelWrapper>(invalid) {
        Ok(_) => panic!("Expected an error here"),
        Err(err) => err,
    };
    assert_eq!(
        err.to_string(),
        "data did not match any variant of untagged enum ModelWrapper"
    );
}
} }