diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs
index ca2590f1..152586ef 100644
--- a/tokenizers/src/models/bpe/serialization.rs
+++ b/tokenizers/src/models/bpe/serialization.rs
@@ -151,18 +151,28 @@ mod test {
 
     #[test]
     fn test_serialization() {
-        let vocab: Vocab = [("".into(), 0), ("a".into(), 1), ("b".into(), 2)]
-            .iter()
-            .cloned()
-            .collect();
+        let vocab: Vocab = [
+            ("".into(), 0),
+            ("a".into(), 1),
+            ("b".into(), 2),
+            ("ab".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
         let bpe = BpeBuilder::default()
-            .vocab_and_merges(vocab, vec![])
+            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
             .unk_token("".to_string())
             .ignore_merges(true)
             .build()
             .unwrap();
 
         let data = serde_json::to_string(&bpe).unwrap();
+        // Pin the exact serialized form; the ("a", "b") merge is expected as "a b".
+        assert_eq!(
+            data,
+            r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
+        );
         let reconstructed = serde_json::from_str(&data).unwrap();
 
         assert_eq!(bpe, reconstructed);
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index bb7cebc4..9dd3dc96 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -204,6 +204,7 @@ impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::models::bpe::{BpeBuilder, Vocab};
 
     #[test]
     fn trainer_wrapper_train_model_wrapper() {
@@ -224,4 +225,49 @@ mod tests {
         let serialized = serde_json::to_string(&ordered).unwrap();
         assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
     }
+
+    // Serializes a BPE model wrapped in `ModelWrapper`, checks the exact JSON,
+    // round-trips it, and checks that a three-part merge entry fails to parse.
+    #[test]
+    fn serialization() {
+        let vocab: Vocab = [
+            ("".into(), 0),
+            ("a".into(), 1),
+            ("b".into(), 2),
+            ("ab".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let bpe = BpeBuilder::default()
+            .vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
+            .unk_token("".to_string())
+            .ignore_merges(true)
+            .build()
+            .unwrap();
+
+        let model = ModelWrapper::BPE(bpe);
+
+        let data = serde_json::to_string(&model).unwrap();
+        assert_eq!(
+            data,
+            r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
+        );
+        let reconstructed = serde_json::from_str(&data).unwrap();
+
+        assert_eq!(model, reconstructed);
+
+        // Same payload as above except for the merge entry "a b c"; the type
+        // annotation is required because nothing else fixes the target type here.
+        let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
+        let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
+            serde_json::from_str(&invalid);
+        match reconstructed {
+            Err(err) => assert_eq!(
+                err.to_string(),
+                "data did not match any variant of untagged enum ModelWrapper"
+            ),
+            _ => panic!("Expected an error here"),
+        }
+    }
 }