save vocab in order of ID

This commit is contained in:
epwalsh
2020-01-21 13:32:13 -08:00
parent da7e629e4a
commit 3a9badd2e0
3 changed files with 43 additions and 1 deletions

View File

@ -564,6 +564,7 @@ dependencies = [
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",

View File

@ -35,6 +35,7 @@ rand = "0.7.2"
regex = "1.3.1"
regex-syntax = "0.6.12"
rayon = "1.2.0"
serde = "1.0"
serde_json = "1.0"
clap = "2.33.0"
unicode-normalization-alignments = "0.1.12"

View File

@ -1,6 +1,7 @@
use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
use crate::tokenizer::{Model, Offsets, Result, Token};
use rand::{thread_rng, Rng};
use serde::{Serialize, Serializer};
use serde_json::Value;
use std::{
collections::HashMap,
@ -426,7 +427,8 @@ impl Model for BPE {
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;
let serialized = serde_json::to_string(&self.vocab)?;
let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
let serialized = serde_json::to_string(&order_vocab_iter)?;
vocab_file.write_all(&serialized.as_bytes())?;
// Write merges.txt
@ -455,11 +457,49 @@ impl Model for BPE {
}
}
/// Serialization adapter over a reverse vocab (ID -> token) that emits
/// entries ordered by token ID, from smallest to largest, instead of the
/// arbitrary iteration order of the underlying `HashMap`.
struct OrderedVocabIter<'a> {
    vocab_r: &'a HashMap<u32, String>,
}
impl<'a> OrderedVocabIter<'a> {
    /// Wraps a borrowed reverse vocab; no copying is performed.
    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
        OrderedVocabIter { vocab_r }
    }
}
impl<'a> Serialize for OrderedVocabIter<'a> {
    /// Serializes the vocab as a map of token -> ID, with entries walked in
    /// ascending ID order (0, 1, 2, ...) rather than `HashMap` order.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // NOTE(review): this indexes the map directly for every ID in
        // 0..len, so it assumes token IDs are contiguous starting at 0; a
        // vocab with holes would panic here — confirm that invariant holds.
        let len = self.vocab_r.len() as u32;
        let entries = (0..len).map(|id| (&self.vocab_r[&id], id));
        serializer.collect_map(entries)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
#[test]
fn test_ordered_vocab_iter() {
    // Build a reverse vocab (ID -> token) and check that serialization
    // emits the entries ordered by ID, independent of HashMap ordering.
    let mut vocab_r: HashMap<u32, String> = HashMap::new();
    for &(id, token) in &[(0u32, "a"), (1, "b"), (2, "c"), (3, "ab")] {
        vocab_r.insert(id, token.to_string());
    }
    let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
    let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
    assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
}
#[test]
// Test tokenization. With dropout set to 0 tokenization is deterministic,
// so we know exactly what the result should be.