mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Fixing Trainer with u8 instead of chars. (#452)
* Fixing Trainer with u8 instead of chars. Now checks both optimized and unoptimized encoding schemes for Unigram.
* Small fixes.
* Fixing makefile.
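The heart of the change: Lattice positions and lengths are now expressed as UTF-8 byte offsets (u8) rather than char indices. A minimal standalone sketch (not part of the diff) of why the two schemes disagree on multibyte text, using the same string as the lattice tests below:

    // "テストab": three 3-byte kana followed by two 1-byte ASCII chars.
    let sentence = "テストab";
    assert_eq!(sentence.chars().count(), 5); // char-based length (old scheme)
    assert_eq!(sentence.len(), 11);          // byte-based length (new scheme)
    // Byte offsets make slicing O(1): &sentence[0..3] == "テ", whereas a
    // char-based scheme must walk the string from the beginning.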
@@ -7,10 +7,9 @@ from ..utils import data_dir, train_files

 class TestUnigram:
     @pytest.mark.slow
     def test_train(self, train_files):
         tokenizer = SentencePieceUnigramTokenizer()
-        tokenizer.train(train_files["big"], show_progress=False)
+        tokenizer.train(train_files["small"], show_progress=False)

         filename = "tests/data/unigram_trained.json"
         tokenizer.save(filename)
@@ -61,8 +61,17 @@ def openai_files(data_dir):

 @pytest.fixture(scope="session")
 def train_files(data_dir):
+    big = download("https://norvig.com/big.txt")
+    small = os.path.join(DATA_PATH, "small.txt")
+    with open(small, "w") as f:
+        with open(big, "r") as g:
+            for i, line in enumerate(g):
+                f.write(line)
+                if i > 100:
+                    break
     return {
-        "big": download("https://norvig.com/big.txt"),
+        "small": small,
+        "big": big,
     }
@@ -4,8 +4,8 @@ TESTS_DIR = tests

 dir_guard=@mkdir -p $(@D)

-SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt
-BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
 TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt

 .PHONY : build
@@ -137,7 +137,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {

 impl<'a> Lattice<'a> {
     pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
-        let len = sentence.chars().count();
+        let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
         let mut nodes: Vec<NodeRef> = Vec::with_capacity(k_reserved_node_size);
@@ -177,7 +177,8 @@ impl<'a> Lattice<'a> {

     pub fn viterbi(&mut self) -> Vec<NodeRef> {
         let len = self.len;
-        for pos in 0..=len {
+        let mut pos = 0;
+        while pos <= len {
             if self.begin_nodes[pos].is_empty() {
                 return vec![];
             }
@@ -201,6 +202,11 @@ impl<'a> Lattice<'a> {
                     None => return vec![],
                 }
             }
+            if let Some(c) = self.sentence[pos..].chars().next() {
+                pos += c.len_utf8();
+            } else {
+                break;
+            }
         }

         let mut results: Vec<NodeRef> = vec![];
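Because `pos` is now a byte offset, the loop can no longer advance with a plain `pos += 1`: offsets that fall inside a multibyte char are not valid boundaries. A hedged sketch of the cursor pattern the rewritten loop relies on, on a standalone string:

    let sentence = "テスト";
    let mut pos = 0;
    while pos < sentence.len() {
        // chars().next() yields the char starting at this byte offset;
        // len_utf8() is its width in bytes (1..=4), so pos always lands
        // on the next char boundary.
        let c = sentence[pos..].chars().next().unwrap();
        pos += c.len_utf8();
    }
    assert_eq!(pos, 9); // three 3-byte chars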
@@ -220,11 +226,7 @@ impl<'a> Lattice<'a> {
     }

     pub fn piece(&self, node: &Node) -> String {
-        self.sentence
-            .chars()
-            .skip(node.pos)
-            .take(node.length)
-            .collect()
+        self.sentence[node.pos..node.pos + node.length].to_owned()
     }

     pub fn tokens(&mut self) -> Vec<String> {
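With byte offsets, `piece` becomes a direct O(1) slice instead of an O(n) `chars().skip().take()` scan. The caveat is that `&s[a..b]` panics when an index is not a char boundary, so the rewrite is only sound because every node's `pos` and `length` are built from char widths. A small illustration (hypothetical input, not from the diff):

    let s = "テスト";
    assert_eq!(&s[0..3], "テ"); // 0 and 3 are char boundaries: fine
    // &s[0..2] would panic: byte 2 falls inside the first char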
@@ -306,14 +308,11 @@ impl<'a> Lattice<'a> {
     pub fn len(&self) -> usize {
         self.len
     }

     pub fn is_empty(&self) -> bool {
         self.len == 0
     }

-    pub fn utf8_len(&self) -> usize {
-        self.sentence.len()
-    }

     pub fn bos_node(&self) -> NodeRef {
         Rc::clone(&self.end_nodes[0][0])
     }
@@ -443,18 +442,14 @@ mod tests {
         let lattice = Lattice::from("", 0, 1, 2);

         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
+        // EXPECT_EQ(0, lattice.utf8_size());

         let lattice = Lattice::from("", 0, 1, 2);
         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
         assert_eq!(lattice.sentence(), "");
         assert_eq!(lattice.surface(0), "");

         let lattice = Lattice::from("test", 0, 1, 2);
         assert_eq!(lattice.len(), 4);
-        assert_eq!(lattice.utf8_len(), 4);
         assert_eq!(lattice.sentence(), "test");
         assert_eq!(lattice.surface(0), "test");
         assert_eq!(lattice.surface(1), "est");
@@ -476,8 +471,7 @@ mod tests {
         );

         let lattice = Lattice::from("テストab", 0, 1, 2);
-        assert_eq!(lattice.len(), 5);
-        assert_eq!(lattice.utf8_len(), 11);
+        assert_eq!(lattice.len(), 11);
         assert_eq!(lattice.sentence(), "テストab");
         assert_eq!(lattice.surface(0), "テストab");
         assert_eq!(lattice.surface(1), "ストab");
@@ -492,11 +486,11 @@ mod tests {

         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
-        lattice.insert(2, 1, 0.0, 5);
-        lattice.insert(3, 1, 0.0, 6);
+        lattice.insert(2, 3, 0.0, 5);
+        lattice.insert(5, 3, 0.0, 6);
         lattice.insert(0, 2, 0.0, 7);
-        lattice.insert(1, 2, 0.0, 8);
-        lattice.insert(2, 2, 0.0, 9);
+        lattice.insert(1, 4, 0.0, 8);
+        lattice.insert(2, 6, 0.0, 9);
         // 0 & 1 are bos and eos
         let node0 = lattice.nodes[2].borrow();
         let node1 = lattice.nodes[3].borrow();
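These new arguments read naturally if the test sentence is assumed to be "ABあい" (two 1-byte chars followed by two 3-byte chars; the sentence itself sits outside this hunk, so this is an inference):

    // token    byte pos  byte len    (old char pos/len)
    // "A"         0         1            (0, 1)
    // "B"         1         1            (1, 1)
    // "あ"        2         3            (2, 1)
    // "い"        5         3            (3, 1)
    // "AB"        0         2            (0, 2)
    // "Bあ"       1         4            (1, 2)
    // "あい"      2         6            (2, 2)

Token ends move from char positions 3 and 4 to byte positions 5 and 8, which is why the assertions below shift from indices 3/4 to 5/8.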
@@ -517,18 +511,18 @@ mod tests {
         assert_eq!(node0.pos, 0);
         assert_eq!(node1.pos, 1);
         assert_eq!(node2.pos, 2);
-        assert_eq!(node3.pos, 3);
+        assert_eq!(node3.pos, 5);
         assert_eq!(node4.pos, 0);
         assert_eq!(node5.pos, 1);
         assert_eq!(node6.pos, 2);

         assert_eq!(node0.length, 1);
         assert_eq!(node1.length, 1);
-        assert_eq!(node2.length, 1);
-        assert_eq!(node3.length, 1);
+        assert_eq!(node2.length, 3);
+        assert_eq!(node3.length, 3);
         assert_eq!(node4.length, 2);
-        assert_eq!(node5.length, 2);
-        assert_eq!(node6.length, 2);
+        assert_eq!(node5.length, 4);
+        assert_eq!(node6.length, 6);

         assert_eq!(lattice.bos_node().borrow().id, 1);
         assert_eq!(lattice.eos_node().borrow().id, 2);
@@ -543,14 +537,14 @@ mod tests {
         assert_eq!(lattice.begin_nodes[0].len(), 2);
         assert_eq!(lattice.begin_nodes[1].len(), 2);
         assert_eq!(lattice.begin_nodes[2].len(), 2);
-        assert_eq!(lattice.begin_nodes[3].len(), 1);
-        assert_eq!(lattice.begin_nodes[4].len(), 1);
+        assert_eq!(lattice.begin_nodes[5].len(), 1);
+        assert_eq!(lattice.begin_nodes[8].len(), 1);

         assert_eq!(lattice.end_nodes[0].len(), 1);
         assert_eq!(lattice.end_nodes[1].len(), 1);
         assert_eq!(lattice.end_nodes[2].len(), 2);
-        assert_eq!(lattice.end_nodes[3].len(), 2);
-        assert_eq!(lattice.end_nodes[4].len(), 2);
+        assert_eq!(lattice.end_nodes[5].len(), 2);
+        assert_eq!(lattice.end_nodes[8].len(), 2);

         assert_eq!(lattice.begin_nodes[0][0].borrow().id, node0.id);
         assert_eq!(lattice.begin_nodes[0][1].borrow().id, node4.id);
@@ -558,10 +552,10 @@ mod tests {
         assert_eq!(lattice.begin_nodes[1][1].borrow().id, node5.id);
         assert_eq!(lattice.begin_nodes[2][0].borrow().id, node2.id);
         assert_eq!(lattice.begin_nodes[2][1].borrow().id, node6.id);
-        assert_eq!(lattice.begin_nodes[3][0].borrow().id, node3.id);
+        assert_eq!(lattice.begin_nodes[5][0].borrow().id, node3.id);
         assert_eq!(
             lattice.eos_node().borrow().id,
-            lattice.begin_nodes[4][0].borrow().id
+            lattice.begin_nodes[8][0].borrow().id
         );

         assert_eq!(
@@ -571,10 +565,10 @@ mod tests {
         assert_eq!(node0.id, lattice.end_nodes[1][0].borrow().id);
         assert_eq!(node1.id, lattice.end_nodes[2][0].borrow().id);
         assert_eq!(node4.id, lattice.end_nodes[2][1].borrow().id);
-        assert_eq!(node2.id, lattice.end_nodes[3][0].borrow().id);
-        assert_eq!(node5.id, lattice.end_nodes[3][1].borrow().id);
-        assert_eq!(node3.id, lattice.end_nodes[4][0].borrow().id);
-        assert_eq!(node6.id, lattice.end_nodes[4][1].borrow().id);
+        assert_eq!(node2.id, lattice.end_nodes[5][0].borrow().id);
+        assert_eq!(node5.id, lattice.end_nodes[5][1].borrow().id);
+        assert_eq!(node3.id, lattice.end_nodes[8][0].borrow().id);
+        assert_eq!(node6.id, lattice.end_nodes[8][1].borrow().id);
     }

     #[test]
@@ -24,6 +24,7 @@ pub struct Unigram {
     pub(super) eos_id: usize,

     fuse_unk: bool,
+    is_optimized: bool,
 }
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
@@ -46,6 +47,7 @@ impl Clone for Unigram {
             bos_id: self.bos_id,
             eos_id: self.eos_id,
             fuse_unk: self.fuse_unk,
+            is_optimized: self.is_optimized,
         }
     }
 }
@@ -122,6 +124,7 @@ impl Unigram {
         }
         let trie = builder.build();
         let fuse_unk = true;
+        let is_optimized = true;

         Ok(Unigram {
             vocab,
@@ -133,6 +136,7 @@ impl Unigram {
             unk_id,
             fuse_unk,
             cache: Cache::default(),
+            is_optimized,
         })
     }

@@ -142,6 +146,11 @@ impl Unigram {
         self.cache = self.cache.fresh();
     }

+    #[cfg(test)]
+    pub(super) fn set_optimized(&mut self, is_optimized: bool) {
+        self.is_optimized = is_optimized;
+    }
+
     pub(super) fn len(&self) -> usize {
         self.vocab.len()
     }
@@ -179,7 +188,7 @@ impl Unigram {
             }

             if !has_single_node {
-                lattice.insert(begin_pos, 1, unk_score, self.unk_id);
+                lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
             }
             begin_pos += mblen
         }
@@ -212,7 +221,11 @@ impl Unigram {
         if let Some(result) = self.cache.get(sentence) {
             result.to_vec()
         } else {
-            let result = self.encode_optimized(sentence);
+            let result = if self.is_optimized {
+                self.encode_optimized(sentence)
+            } else {
+                self.encode_unoptimized(sentence)
+            };
             self.cache.set(sentence.to_owned(), result.clone());
             result
         }
@@ -327,7 +340,6 @@ impl Unigram {
         results
     }

-    #[allow(dead_code)]
     fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
         let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
@@ -541,12 +553,17 @@ mod tests {
         ];

         let mut model = Unigram::from(sentencepieces, 0).unwrap();

+        for is_optimized in &[true, false] {
+            model.set_optimized(*is_optimized);
+            println!("IsOptimized {:?}", is_optimized);
             assert_eq!(model.encode("abc"), vec!["abc"]);
             assert_eq!(model.encode("AB"), vec!["AB"]);

             model.set_fuse_unk(false);
             assert_eq!(model.encode("AB"), vec!["A", "B"]);
             model.set_fuse_unk(true);
             assert_eq!(model.encode("AB"), vec!["AB"]);

             assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
             assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
@@ -557,6 +574,7 @@ mod tests {
             model.set_fuse_unk(false);
             assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
             model.set_fuse_unk(true);
+            assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]);

             // User encoded in original version
             assert_eq!(model.encode("ABC"), vec!["ABC"]);
@@ -565,3 +583,4 @@ mod tests {
             assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
+        }
     }
 }
@@ -1,14 +1,13 @@
 #[cfg(not(debug_assertions))]
 use assert_approx_eq::assert_approx_eq;
-#[cfg(not(debug_assertions))]
 use std::collections::HashMap;
-#[cfg(not(debug_assertions))]
 use std::fs::read_to_string;
 use std::path::Path;
 #[cfg(not(debug_assertions))]
 use tokenizers::models::unigram::Lattice;
 use tokenizers::models::unigram::Unigram;
-use tokenizers::tokenizer::Model;
+use tokenizers::models::unigram::UnigramTrainer;
+use tokenizers::tokenizer::{Model, Trainer};

 #[test]
 fn test_unigram_from_file() {
@@ -39,6 +38,26 @@ fn test_unigram_from_file() {
     );
 }

+#[test]
+fn test_train_unigram_from_file() {
+    let content = read_to_string("data/small.txt").unwrap();
+    let mut word_counts = HashMap::new();
+    content.split_whitespace().for_each(|word| {
+        // This is important for the test of char vs u8
+        let word = format!("▁{}", word.to_string());
+        *word_counts.entry(word).or_insert(0) += 1;
+    });
+
+    // println!("Words counts {:?}", word_counts);
+
+    let trainer = UnigramTrainer::builder()
+        .show_progress(false)
+        .build()
+        .unwrap();
+    let (model, _) = trainer.train(word_counts).unwrap();
+    assert_eq!(model.get_vocab_size(), 719);
+}
+
 #[cfg(not(debug_assertions))]
 #[test]
 fn test_sample() {
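A note on why the "▁" prefix matters for the char-vs-u8 distinction: "▁" (U+2581, the SentencePiece word-boundary marker) is one char but three bytes in UTF-8, so every word fed to the trainer starts with a multibyte character and the byte-based handling is exercised on each entry. A standalone sketch of the property relied on:

    let marker = "▁";
    assert_eq!(marker.chars().count(), 1); // one char...
    assert_eq!(marker.len(), 3);           // ...but three bytes in UTF-8
    assert_eq!(format!("{}{}", marker, "hello").len(), 8); // 3 + 5 bytes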