mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
save vocab in order of ID
This commit is contained in:
1
bindings/python/Cargo.lock
generated
1
bindings/python/Cargo.lock
generated
@ -564,6 +564,7 @@ dependencies = [
|
||||
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
@ -35,6 +35,7 @@ rand = "0.7.2"
|
||||
regex = "1.3.1"
|
||||
regex-syntax = "0.6.12"
|
||||
rayon = "1.2.0"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
clap = "2.33.0"
|
||||
unicode-normalization-alignments = "0.1.12"
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
|
||||
use crate::tokenizer::{Model, Offsets, Result, Token};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde::{Serialize, Serializer};
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@ -426,7 +427,8 @@ impl Model for BPE {
|
||||
.iter()
|
||||
.collect();
|
||||
let mut vocab_file = File::create(&vocab_path)?;
|
||||
let serialized = serde_json::to_string(&self.vocab)?;
|
||||
let order_vocab_iter = OrderedVocabIter::new(&self.vocab_r);
|
||||
let serialized = serde_json::to_string(&order_vocab_iter)?;
|
||||
vocab_file.write_all(&serialized.as_bytes())?;
|
||||
|
||||
// Write merges.txt
|
||||
@ -455,11 +457,49 @@ impl Model for BPE {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps a vocab mapping (ID -> token) to a struct that will be serialized in order
/// of token ID, smallest to largest.
///
/// Used when saving the vocab so the JSON output lists tokens by ascending ID
/// instead of arbitrary `HashMap` iteration order.
struct OrderedVocabIter<'a> {
    // Reverse vocab (ID -> token), borrowed from the model's `vocab_r`.
    vocab_r: &'a HashMap<u32, String>,
}
|
||||
|
||||
impl<'a> OrderedVocabIter<'a> {
|
||||
fn new(vocab_r: &'a HashMap<u32, String>) -> Self {
|
||||
Self { vocab_r }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Serialize for OrderedVocabIter<'a> {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let iter = (0u32..(self.vocab_r.len() as u32)).map(|i| (&self.vocab_r[&i], i));
|
||||
serializer.collect_map(iter)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
fn test_ordered_vocab_iter() {
    // Build a reverse vocab (ID -> token) with four entries whose HashMap
    // iteration order is arbitrary.
    let mut vocab_r: HashMap<u32, String> = HashMap::new();
    vocab_r.insert(0, "a".into());
    vocab_r.insert(1, "b".into());
    vocab_r.insert(2, "c".into());
    vocab_r.insert(3, "ab".into());

    // Serializing through the wrapper must emit entries in ascending ID order.
    let order_vocab_iter = OrderedVocabIter::new(&vocab_r);
    let serialized = serde_json::to_string(&order_vocab_iter).unwrap();
    assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
}
|
||||
|
||||
#[test]
|
||||
// Test tokenization. With dropout set to 0 tokenization is deterministic,
|
||||
// so we know exactly what the result should be.
|
||||
|
Reference in New Issue
Block a user