mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Adding some serialization testing around the wrapper.
This commit is contained in:
@ -151,18 +151,27 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_serialization() {
|
||||
let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
|
||||
let vocab: Vocab = [
|
||||
("<unk>".into(), 0),
|
||||
("a".into(), 1),
|
||||
("b".into(), 2),
|
||||
("ab".into(), 3),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let bpe = BpeBuilder::default()
|
||||
.vocab_and_merges(vocab, vec![])
|
||||
.vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
|
||||
.unk_token("<unk>".to_string())
|
||||
.ignore_merges(true)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let data = serde_json::to_string(&bpe).unwrap();
|
||||
assert_eq!(
|
||||
data,
|
||||
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
|
||||
);
|
||||
let reconstructed = serde_json::from_str(&data).unwrap();
|
||||
|
||||
assert_eq!(bpe, reconstructed);
|
||||
|
@ -204,6 +204,7 @@ impl_enum_from!(WordLevelTrainer, TrainerWrapper, WordLevelTrainer);
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::models::bpe::{BpeBuilder, Vocab};
|
||||
|
||||
#[test]
|
||||
fn trainer_wrapper_train_model_wrapper() {
|
||||
@ -224,4 +225,45 @@ mod tests {
|
||||
let serialized = serde_json::to_string(&ordered).unwrap();
|
||||
assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialization() {
|
||||
let vocab: Vocab = [
|
||||
("<unk>".into(), 0),
|
||||
("a".into(), 1),
|
||||
("b".into(), 2),
|
||||
("ab".into(), 3),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
let bpe = BpeBuilder::default()
|
||||
.vocab_and_merges(vocab, vec![("a".to_string(), "b".to_string())])
|
||||
.unk_token("<unk>".to_string())
|
||||
.ignore_merges(true)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let model = ModelWrapper::BPE(bpe);
|
||||
|
||||
let data = serde_json::to_string(&model).unwrap();
|
||||
assert_eq!(
|
||||
data,
|
||||
r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
|
||||
);
|
||||
let reconstructed = serde_json::from_str(&data).unwrap();
|
||||
|
||||
assert_eq!(model, reconstructed);
|
||||
|
||||
let invalid = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b c"]}"#;
|
||||
let reconstructed: std::result::Result<ModelWrapper, serde_json::Error> =
|
||||
serde_json::from_str(&invalid);
|
||||
match reconstructed {
|
||||
Err(err) => assert_eq!(
|
||||
err.to_string(),
|
||||
"data did not match any variant of untagged enum ModelWrapper"
|
||||
),
|
||||
_ => panic!("Expected an error here"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user