Fixing issue with ConvBert not being able to save because of holes in (#954)

the vocab.
This commit is contained in:
Nicolas Patry
2022-03-21 19:28:49 +01:00
committed by GitHub
parent 1bb9884f45
commit cd730594e9
4 changed files with 45 additions and 6 deletions

View File

@ -1743,6 +1743,7 @@ dependencies = [
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",

View File

@ -1719,18 +1719,18 @@ dependencies = [
[[package]]
name = "thiserror"
version = "1.0.29"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88"
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.29"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c"
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
dependencies = [
"proc-macro2",
"quote",
@ -1788,6 +1788,7 @@ dependencies = [
"serde",
"serde_json",
"spm_precompiled",
"thiserror",
"unicode-normalization-alignments",
"unicode-segmentation",
"unicode_categories",

View File

@ -33,7 +33,17 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
where
S: Serializer,
{
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
// There could be holes so max + 1 is more correct than vocab_r.len()
let max = self.vocab_r.iter().map(|(key, _)| key).max().unwrap_or(&0) + 1;
let iter = (0..max).filter_map(|i| {
if let Some(token) = self.vocab_r.get(&i){
Some((token, i))
}else{
warn!("The OrderedVocab you are attempting to save contains a hole for index {}, your vocabulary could be corrupted !", i);
println!("The OrderedVocab you are attempting to save contains a hole for index {}, your vocabulary could be corrupted !", i);
None
}
});
serializer.collect_map(iter)
}
}
@ -194,4 +204,15 @@ mod tests {
let result = trainer.train(&mut model);
assert!(result.is_err());
}
#[test]
fn incomplete_ordered_vocab() {
    // Reverse vocab with a gap at index 1; serializing must skip the hole
    // instead of panicking, producing entries only for the ids present.
    let mut vocab_r: HashMap<u32, String> = HashMap::new();
    vocab_r.insert(0, String::from("Hi"));
    vocab_r.insert(2, String::from("There"));

    let ordered = OrderedVocabIter::new(&vocab_r);
    let serialized = serde_json::to_string(&ordered).unwrap();
    assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
}
}

View File

@ -83,7 +83,7 @@ impl<'de> Visitor<'de> for WordLevelVisitor {
#[cfg(test)]
mod tests {
use super::*;
use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder};
#[test]
fn serde() {
@ -94,6 +94,22 @@ mod tests {
assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wl);
}
#[test]
fn incomplete_vocab() {
    // Vocab with a hole at id 1; a WordLevel model built from it must
    // serialize the sparse ids as-is and round-trip through serde_json.
    let vocab: Vocab = vec![("<unk>".to_string(), 0), ("b".to_string(), 2)]
        .into_iter()
        .collect();
    let wordlevel = WordLevelBuilder::default()
        .vocab(vocab)
        .unk_token(String::from("<unk>"))
        .build()
        .unwrap();

    let expected = r#"{"type":"WordLevel","vocab":{"<unk>":0,"b":2},"unk_token":"<unk>"}"#;
    assert_eq!(serde_json::to_string(&wordlevel).unwrap(), expected);
    assert_eq!(serde_json::from_str::<WordLevel>(expected).unwrap(), wordlevel);
}
#[test]
fn deserialization_should_fail() {
let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#;