mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 12:39:21 +00:00
Fixing issue with ConvBert not being able to save because of holes in (#954)
the vocab.
This commit is contained in:
1
bindings/node/native/Cargo.lock
generated
1
bindings/node/native/Cargo.lock
generated
@ -1743,6 +1743,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"spm_precompiled",
|
"spm_precompiled",
|
||||||
|
"thiserror",
|
||||||
"unicode-normalization-alignments",
|
"unicode-normalization-alignments",
|
||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
"unicode_categories",
|
"unicode_categories",
|
||||||
|
9
bindings/python/Cargo.lock
generated
9
bindings/python/Cargo.lock
generated
@ -1719,18 +1719,18 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror"
|
name = "thiserror"
|
||||||
version = "1.0.29"
|
version = "1.0.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "602eca064b2d83369e2b2f34b09c70b605402801927c65c11071ac911d299b88"
|
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"thiserror-impl",
|
"thiserror-impl",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "thiserror-impl"
|
name = "thiserror-impl"
|
||||||
version = "1.0.29"
|
version = "1.0.30"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bad553cc2c78e8de258400763a647e80e6d1b31ee237275d756f6836d204494c"
|
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
@ -1788,6 +1788,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"spm_precompiled",
|
"spm_precompiled",
|
||||||
|
"thiserror",
|
||||||
"unicode-normalization-alignments",
|
"unicode-normalization-alignments",
|
||||||
"unicode-segmentation",
|
"unicode-segmentation",
|
||||||
"unicode_categories",
|
"unicode_categories",
|
||||||
|
@ -33,7 +33,17 @@ impl<'a> Serialize for OrderedVocabIter<'a> {
|
|||||||
where
|
where
|
||||||
S: Serializer,
|
S: Serializer,
|
||||||
{
|
{
|
||||||
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
|
// There could be holes so max + 1 is more correct than vocab_r.len()
|
||||||
|
let max = self.vocab_r.iter().map(|(key, _)| key).max().unwrap_or(&0) + 1;
|
||||||
|
let iter = (0..max).filter_map(|i| {
|
||||||
|
if let Some(token) = self.vocab_r.get(&i){
|
||||||
|
Some((token, i))
|
||||||
|
}else{
|
||||||
|
warn!("The OrderedVocab you are attempting to save contains a hole for index {}, your vocabulary could be corrupted !", i);
|
||||||
|
println!("The OrderedVocab you are attempting to save contains a hole for index {}, your vocabulary could be corrupted !", i);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
});
|
||||||
serializer.collect_map(iter)
|
serializer.collect_map(iter)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -194,4 +204,15 @@ mod tests {
|
|||||||
let result = trainer.train(&mut model);
|
let result = trainer.train(&mut model);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn incomplete_ordered_vocab() {
|
||||||
|
let vocab_r: HashMap<u32, String> =
|
||||||
|
HashMap::from([(0, "Hi".to_string()), (2, "There".to_string())]);
|
||||||
|
|
||||||
|
let ordered = OrderedVocabIter::new(&vocab_r);
|
||||||
|
|
||||||
|
let serialized = serde_json::to_string(&ordered).unwrap();
|
||||||
|
assert_eq!(serialized, "{\"Hi\":0,\"There\":2}");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,7 +83,7 @@ impl<'de> Visitor<'de> for WordLevelVisitor {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn serde() {
|
fn serde() {
|
||||||
@ -94,6 +94,22 @@ mod tests {
|
|||||||
assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wl);
|
assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn incomplete_vocab() {
|
||||||
|
let vocab: Vocab = [("<unk>".into(), 0), ("b".into(), 2)]
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
let wordlevel = WordLevelBuilder::default()
|
||||||
|
.vocab(vocab)
|
||||||
|
.unk_token("<unk>".to_string())
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let wl_s = r#"{"type":"WordLevel","vocab":{"<unk>":0,"b":2},"unk_token":"<unk>"}"#;
|
||||||
|
assert_eq!(serde_json::to_string(&wordlevel).unwrap(), wl_s);
|
||||||
|
assert_eq!(serde_json::from_str::<WordLevel>(wl_s).unwrap(), wordlevel);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn deserialization_should_fail() {
|
fn deserialization_should_fail() {
|
||||||
let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#;
|
let missing_unk = r#"{"type":"WordLevel","vocab":{}}"#;
|
||||||
|
Reference in New Issue
Block a user