mirror of https://github.com/mii443/tokenizers.git, synced 2025-12-07 21:28:19 +00:00
Merge pull request #95 from huggingface/vocab-serialization
save BPE vocab in order of ID
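Background for the change (a minimal sketch, not part of the diff): BPE::save previously serialized self.vocab, a HashMap<String, u32>, directly. A HashMap's iteration order is unspecified, so the entries in the written vocab file came out in arbitrary order. The toy program below, with an illustrative four-token vocab, reproduces the effect:

use std::collections::HashMap;

fn main() {
    // Toy vocab (token -> ID), mirroring the map BPE::save used to serialize.
    let mut vocab: HashMap<String, u32> = HashMap::new();
    for (id, token) in ["a", "b", "c", "ab"].iter().enumerate() {
        vocab.insert(token.to_string(), id as u32);
    }
    // HashMap iteration order is unspecified: this JSON may list the four
    // entries in any order, and the order can differ from run to run.
    println!("{}", serde_json::to_string(&vocab).unwrap());
}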
bindings/python/Cargo.lock (generated, 1 addition):
@@ -564,6 +564,7 @@ dependencies = [
  "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
+ "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
  "unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
  "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
In the tokenizers crate's Cargo.toml:

@@ -35,6 +35,7 @@ rand = "0.7.2"
 regex = "1.3.1"
 regex-syntax = "0.6.12"
 rayon = "1.2.0"
+serde = "1.0"
 serde_json = "1.0"
 clap = "2.33.0"
 unicode-normalization-alignments = "0.1.12"
In the BPE model source:

@@ -1,6 +1,7 @@
 use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
 use crate::tokenizer::{Model, Offsets, Result, Token};
 use rand::{thread_rng, Rng};
+use serde::{Serialize, Serializer};
 use serde_json::Value;
 use std::{
     collections::HashMap,
@@ -426,7 +427,8 @@ impl Model for BPE {
             .iter()
             .collect();
         let mut vocab_file = File::create(&vocab_path)?;
-        let serialized = serde_json::to_string(&self.vocab)?;
+        let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
+        let serialized = serde_json::to_string(&order_vocab_iter)?;
         vocab_file.write_all(&serialized.as_bytes())?;

         // Write merges.txt
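The effect of the two-line swap above: save() no longer hands serde the token -> ID map directly, but goes through OrderedVocabIter (added in the next hunk), which walks the reverse map vocab_r by ID starting at 0. For a four-token vocab the serialized output is deterministic, e.g. {"a":0,"b":1,"c":2,"ab":3}, which is exactly what the new unit test asserts.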
@@ -455,11 +457,49 @@ impl Model for BPE {
         }
     }
 }

+/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
+/// of token ID, smallest to largest.
+struct OrderedVocabIter<'a> {
+    vocab_r: &'a HashMap<u32, String>,
+}
+
+impl<'a> OrderedVocabIter<'a> {
+    fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
+        Self { vocab_r }
+    }
+}
+
+impl<'a> Serialize for OrderedVocabIter<'a> {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
+        serializer.collect_map(iter)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use tempfile::NamedTempFile;

+    #[test]
+    fn test_ordered_vocab_iter() {
+        let vocab_r: HashMap<u32, String> = [
+            (0, "a".into()),
+            (1, "b".into()),
+            (2, "c".into()),
+            (3, "ab".into()),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
+        let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
+        assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
+    }
+
     #[test]
     // Test tokenization. With dropout set to 0 tokenization is deterministic,
     // so we know exactly what the result should be.
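Two notes on the wrapper. First, iterating 0..vocab_r.len() guarantees ascending-ID output but assumes token IDs are contiguous from 0; self.vocab_r[&i] uses HashMap indexing, which panics if an ID in that range is missing. Second, readers of the file are unaffected either way, since JSON object key order carries no meaning when deserializing back into a map; the ordering only makes the saved vocab stable and diffable. A minimal round-trip sketch (the vocab.json file name is an assumption following the convention implied by save(); error handling simplified):

use std::collections::HashMap;
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load the vocab written by BPE::save. Key order in the file does not
    // affect the reconstructed map; ID-sorted output just helps humans and diffs.
    let contents = fs::read_to_string("vocab.json")?;
    let vocab: HashMap<String, u32> = serde_json::from_str(&contents)?;
    println!("loaded {} tokens", vocab.len());
    Ok(())
}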