Merge pull request #167 from huggingface/hotfix_models_save

Hotfix models save
This commit is contained in:
MOI Anthony
2020-02-24 16:04:09 -05:00
committed by GitHub
4 changed files with 36 additions and 36 deletions

View File

@@ -1,7 +1,9 @@
use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
use super::{
super::OrderedVocabIter, Cache, Error, Pair, WithFirstLastIterator, Word,
DEFAULT_CACHE_CAPACITY,
};
use crate::tokenizer::{Model, Offsets, Result, Token};
use rand::{thread_rng, Rng};
use serde::{Serialize, Serializer};
use serde_json::Value;
use std::{
collections::HashMap,
@@ -463,28 +465,6 @@ impl Model for BPE {
}
}
/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
/// of token ID, smallest to largest.
struct OrderedVocabIter<'a> {
vocab_r: &'a HashMap<u32, String>,
}
impl<'a> OrderedVocabIter<'a> {
fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
Self { vocab_r }
}
}
impl<'a> Serialize for OrderedVocabIter<'a> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
serializer.collect_map(iter)
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -3,3 +3,28 @@
pub mod bpe;
pub mod wordlevel;
pub mod wordpiece;
use serde::{Serialize, Serializer};
use std::collections::HashMap;
/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
/// of token ID, smallest to largest.
struct OrderedVocabIter<'a> {
vocab_r: &'a HashMap<u32, String>,
}
impl<'a> OrderedVocabIter<'a> {
fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
Self { vocab_r }
}
}
impl<'a> Serialize for OrderedVocabIter<'a> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
{
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
serializer.collect_map(iter)
}
}

View File

@@ -1,3 +1,4 @@
use super::OrderedVocabIter;
use crate::tokenizer::{Model, Result, Token};
use serde_json::Value;
use std::collections::HashMap;
@@ -168,20 +169,14 @@ impl Model for WordLevel {
None => "vocab.json".to_string(),
};
// Write vocab.txt
// Write vocab.json
let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;
let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
vocab.sort_unstable_by_key(|k| *k.1);
vocab_file.write_all(
&vocab
.into_iter()
.map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
.flatten()
.collect::<Vec<_>>()[..],
)?;
let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
let serialized = serde_json::to_string(&order_vocab_iter)?;
vocab_file.write_all(&serialized.as_bytes())?;
Ok(vec![vocab_path])
}

View File

@@ -252,8 +252,8 @@ impl Model for WordPiece {
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name),
None => "vocab.json".to_string(),
Some(name) => format!("{}-vocab.txt", name),
None => "vocab.txt".to_string(),
};
// Write vocab.txt