mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-18 06:19:14 +00:00
Merge pull request #167 from huggingface/hotfix_models_save
Hotfix models save
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
|
||||
use super::{
|
||||
super::OrderedVocabIter, Cache, Error, Pair, WithFirstLastIterator, Word,
|
||||
DEFAULT_CACHE_CAPACITY,
|
||||
};
|
||||
use crate::tokenizer::{Model, Offsets, Result, Token};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@@ -463,28 +465,6 @@ impl Model for BPE {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
|
||||
/// of token ID, smallest to largest.
|
||||
struct OrderedVocabIter<'a> {
|
||||
vocab_r: &'a HashMap<u32, String>,
|
||||
}
|
||||
|
||||
impl<'a> OrderedVocabIter<'a> {
|
||||
fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
|
||||
Self { vocab_r }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Serialize for OrderedVocabIter<'a> {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
|
||||
serializer.collect_map(iter)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -3,3 +3,28 @@
|
||||
pub mod bpe;
|
||||
pub mod wordlevel;
|
||||
pub mod wordpiece;
|
||||
|
||||
use serde::{Serialize, Serializer};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
|
||||
/// of token ID, smallest to largest.
|
||||
struct OrderedVocabIter<'a> {
|
||||
vocab_r: &'a HashMap<u32, String>,
|
||||
}
|
||||
|
||||
impl<'a> OrderedVocabIter<'a> {
|
||||
fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
|
||||
Self { vocab_r }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Serialize for OrderedVocabIter<'a> {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
|
||||
serializer.collect_map(iter)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::OrderedVocabIter;
|
||||
use crate::tokenizer::{Model, Result, Token};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
@@ -168,20 +169,14 @@ impl Model for WordLevel {
|
||||
None => "vocab.json".to_string(),
|
||||
};
|
||||
|
||||
// Write vocab.txt
|
||||
// Write vocab.json
|
||||
let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
|
||||
.iter()
|
||||
.collect();
|
||||
let mut vocab_file = File::create(&vocab_path)?;
|
||||
let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
|
||||
vocab.sort_unstable_by_key(|k| *k.1);
|
||||
vocab_file.write_all(
|
||||
&vocab
|
||||
.into_iter()
|
||||
.map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()[..],
|
||||
)?;
|
||||
let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
|
||||
let serialized = serde_json::to_string(&order_vocab_iter)?;
|
||||
vocab_file.write_all(&serialized.as_bytes())?;
|
||||
|
||||
Ok(vec![vocab_path])
|
||||
}
|
||||
|
||||
@@ -252,8 +252,8 @@ impl Model for WordPiece {
|
||||
|
||||
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
|
||||
let vocab_file_name = match name {
|
||||
Some(name) => format!("{}-vocab.json", name),
|
||||
None => "vocab.json".to_string(),
|
||||
Some(name) => format!("{}-vocab.txt", name),
|
||||
None => "vocab.txt".to_string(),
|
||||
};
|
||||
|
||||
// Write vocab.txt
|
||||
|
||||
Reference in New Issue
Block a user