Fixing issue with ConvBert not being able to save because of holes in (#954)

the vocab.
This commit is contained in:
Nicolas Patry
2022-03-21 19:28:49 +01:00
committed by GitHub
parent 1bb9884f45
commit cd730594e9
4 changed files with 45 additions and 6 deletions

View File

@ -1743,6 +1743,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
"thiserror",
"unicode-normalization-alignments", "unicode-normalization-alignments",
"unicode-segmentation", "unicode-segmentation",
"unicode_categories", "unicode_categories",

View File

@ -1719,18 +1719,18 @@ dependencies = [
[[package]] [[package]]
name = "thiserror" name = "thiserror"
version = "1.0.29" version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88" checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
dependencies = [ dependencies = [
"thiserror-impl", "thiserror-impl",
] ]
[[package]] [[package]]
name = "thiserror-impl" name = "thiserror-impl"
version = "1.0.29" version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c" checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1788,6 +1788,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
"thiserror",
"unicode-normalization-alignments", "unicode-normalization-alignments",
"unicode-segmentation", "unicode-segmentation",
"unicode_categories", "unicode_categories",

View File

@ -33,7 +33,17 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
where where
S: Serializer, S: Serializer,
{ {
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i)); // There could be holes so max + 1 is more correct than vocab_r.len()
let max = self.vocab_r.iter().map(|(key, _)| key).max().unwrap_or(&0) + 1;
let iter = (0..max).filter_map(|i| {
if let Some(token) = self.vocab_r.get(&i){
Some((token, i))
}else{
warn!("The OrderedVocab you are attempting to save contains a hole for index {}, your vocabulary could be corrupted !", i);
println!("The OrderedVocab you are attempting to save contains a hole for index {}, your vocabulary could be corrupted !", i);
None
}
});
serializer.collect_map(iter) serializer.collect_map(iter)
} }
} }
@ -194,4 +204,15 @@ mod tests {
let result = trainer.train(&mut model); let result = trainer.train(&mut model);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test]
fn incomplete_ordered_vocab() {
let vocab_r: HashMap<u32, String> =
HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
let ordered = OrderedVocabIter::new(&vocab_r);
let serialized = serde_json::to_string(&ordered).unwrap();
assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
}
} }

View File

@ -83,7 +83,7 @@ impl<'de> Visitor<'de> for WordLevelVisitor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder};
#[test] #[test]
fn serde() { fn serde() {
@ -94,6 +94,22 @@ mod tests {
assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wl); assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wl);
} }
#[test]
fn incomplete_vocab() {
let vocab: Vocab = [("<unk>".into(), 0), ("b".into(), 2)]
.iter()
.cloned()
.collect();
let wordlevel = WordLevelBuilder::default()
.vocab(vocab)
.unk_token("<unk>".to_string())
.build()
.unwrap();
let wl_s = r#"{"type":"WordLevel","vocab":{"<unk>":0,"b":2},"unk_token":"<unk>"}"#;
assert_eq!(serde_json::to_string(&wordlevel).unwrap(), wl_s);
assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wordlevel);
}
#[test] #[test]
fn deserialization_should_fail() { fn deserialization_should_fail() {
let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#; let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#;