diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
index 79f26b4e..8800b265 100644
--- a/bindings/python/tests/bindings/test_trainers.py
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -7,10 +7,9 @@ from ..utils import data_dir, train_files
 
 
 class TestUnigram:
-    @pytest.mark.slow
     def test_train(self, train_files):
         tokenizer = SentencePieceUnigramTokenizer()
-        tokenizer.train(train_files["big"], show_progress=False)
+        tokenizer.train(train_files["small"], show_progress=False)
 
         filename = "tests/data/unigram_trained.json"
         tokenizer.save(filename)
diff --git a/bindings/python/tests/utils.py b/bindings/python/tests/utils.py
index 62efaabf..45c2a440 100644
--- a/bindings/python/tests/utils.py
+++ b/bindings/python/tests/utils.py
@@ -61,8 +61,17 @@ def openai_files(data_dir):
 
 
 @pytest.fixture(scope="session")
 def train_files(data_dir):
+    big = download("https://norvig.com/big.txt")
+    small = os.path.join(DATA_PATH, "small.txt")
+    with open(small, "w") as f:
+        with open(big, "r") as g:
+            for i, line in enumerate(g):
+                f.write(line)
+                if i > 100:
+                    break
     return {
-        "big": download("https://norvig.com/big.txt"),
+        "small": small,
+        "big": big,
     }
 
diff --git a/tokenizers/Makefile b/tokenizers/Makefile
index 16da3b81..16f434c2 100644
--- a/tokenizers/Makefile
+++ b/tokenizers/Makefile
@@ -4,8 +4,8 @@ TESTS_DIR = tests
 
 dir_guard=@mkdir -p $(@D)
 
-SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt
-BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
 TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt
 
 .PHONY : build
diff --git a/tokenizers/src/models/unigram/lattice.rs b/tokenizers/src/models/unigram/lattice.rs
index 3afd2b71..a8dbe208 100644
--- a/tokenizers/src/models/unigram/lattice.rs
+++ b/tokenizers/src/models/unigram/lattice.rs
@@ -137,7 +137,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
 
 impl<'a> Lattice<'a> {
     pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
-        let len = sentence.chars().count();
+        let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
         let mut nodes: Vec<NodeRef> = Vec::with_capacity(k_reserved_node_size);
@@ -177,7 +177,8 @@
 
     pub fn viterbi(&mut self) -> Vec<NodeRef> {
         let len = self.len;
-        for pos in 0..=len {
+        let mut pos = 0;
+        while pos <= len {
             if self.begin_nodes[pos].is_empty() {
                 return vec![];
             }
@@ -201,6 +202,11 @@
                     None => return vec![],
                 }
             }
+            if let Some(c) = self.sentence[pos..].chars().next() {
+                pos += c.len_utf8();
+            } else {
+                break;
+            }
         }
 
         let mut results: Vec<NodeRef> = vec![];
@@ -220,11 +226,7 @@
     }
 
     pub fn piece(&self, node: &Node) -> String {
-        self.sentence
-            .chars()
-            .skip(node.pos)
-            .take(node.length)
-            .collect()
+        self.sentence[node.pos..node.pos + node.length].to_owned()
     }
 
     pub fn tokens(&mut self) -> Vec<String> {
@@ -306,14 +308,11 @@
     pub fn len(&self) -> usize {
         self.len
     }
+
     pub fn is_empty(&self) -> bool {
         self.len == 0
     }
 
-    pub fn utf8_len(&self) -> usize {
-        self.sentence.len()
-    }
-
     pub fn bos_node(&self) -> NodeRef {
         Rc::clone(&self.end_nodes[0][0])
     }
@@ -443,18 +442,14 @@ mod tests {
         let lattice = Lattice::from("", 0, 1, 2);
 
         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
-        // EXPECT_EQ(0, lattice.utf8_size());
 
         let lattice = Lattice::from("", 0, 1, 2);
         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
         assert_eq!(lattice.sentence(), "");
         assert_eq!(lattice.surface(0), "");
 
         let lattice = Lattice::from("test", 0, 1, 2);
         assert_eq!(lattice.len(), 4);
-        assert_eq!(lattice.utf8_len(), 4);
         assert_eq!(lattice.sentence(), "test");
         assert_eq!(lattice.surface(0), "test");
         assert_eq!(lattice.surface(1), "est");
@@ -476,8 +471,7 @@ mod tests {
         );
 
         let lattice = Lattice::from("テストab", 0, 1, 2);
-        assert_eq!(lattice.len(), 5);
-        assert_eq!(lattice.utf8_len(), 11);
+        assert_eq!(lattice.len(), 11);
         assert_eq!(lattice.sentence(), "テストab");
         assert_eq!(lattice.surface(0), "テストab");
         assert_eq!(lattice.surface(1), "ストab");
@@ -492,11 +486,11 @@ mod tests {
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
-        lattice.insert(2, 1, 0.0, 5);
-        lattice.insert(3, 1, 0.0, 6);
+        lattice.insert(2, 3, 0.0, 5);
+        lattice.insert(5, 3, 0.0, 6);
         lattice.insert(0, 2, 0.0, 7);
-        lattice.insert(1, 2, 0.0, 8);
-        lattice.insert(2, 2, 0.0, 9);
+        lattice.insert(1, 4, 0.0, 8);
+        lattice.insert(2, 6, 0.0, 9);
 
         // 0 & 1 are bos and eos
         let node0 = lattice.nodes[2].borrow();
         let node1 = lattice.nodes[3].borrow();
@@ -517,18 +511,18 @@ mod tests {
         assert_eq!(node0.pos, 0);
         assert_eq!(node1.pos, 1);
         assert_eq!(node2.pos, 2);
-        assert_eq!(node3.pos, 3);
+        assert_eq!(node3.pos, 5);
         assert_eq!(node4.pos, 0);
         assert_eq!(node5.pos, 1);
         assert_eq!(node6.pos, 2);
 
         assert_eq!(node0.length, 1);
         assert_eq!(node1.length, 1);
-        assert_eq!(node2.length, 1);
-        assert_eq!(node3.length, 1);
+        assert_eq!(node2.length, 3);
+        assert_eq!(node3.length, 3);
         assert_eq!(node4.length, 2);
-        assert_eq!(node5.length, 2);
-        assert_eq!(node6.length, 2);
+        assert_eq!(node5.length, 4);
+        assert_eq!(node6.length, 6);
 
         assert_eq!(lattice.bos_node().borrow().id, 1);
         assert_eq!(lattice.eos_node().borrow().id, 2);
@@ -543,14 +537,14 @@ mod tests {
         assert_eq!(lattice.begin_nodes[0].len(), 2);
         assert_eq!(lattice.begin_nodes[1].len(), 2);
         assert_eq!(lattice.begin_nodes[2].len(), 2);
-        assert_eq!(lattice.begin_nodes[3].len(), 1);
-        assert_eq!(lattice.begin_nodes[4].len(), 1);
+        assert_eq!(lattice.begin_nodes[5].len(), 1);
+        assert_eq!(lattice.begin_nodes[8].len(), 1);
 
         assert_eq!(lattice.end_nodes[0].len(), 1);
         assert_eq!(lattice.end_nodes[1].len(), 1);
         assert_eq!(lattice.end_nodes[2].len(), 2);
-        assert_eq!(lattice.end_nodes[3].len(), 2);
-        assert_eq!(lattice.end_nodes[4].len(), 2);
+        assert_eq!(lattice.end_nodes[5].len(), 2);
+        assert_eq!(lattice.end_nodes[8].len(), 2);
 
         assert_eq!(lattice.begin_nodes[0][0].borrow().id, node0.id);
         assert_eq!(lattice.begin_nodes[0][1].borrow().id, node4.id);
@@ -558,10 +552,10 @@ mod tests {
         assert_eq!(lattice.begin_nodes[1][1].borrow().id, node5.id);
         assert_eq!(lattice.begin_nodes[2][0].borrow().id, node2.id);
         assert_eq!(lattice.begin_nodes[2][1].borrow().id, node6.id);
-        assert_eq!(lattice.begin_nodes[3][0].borrow().id, node3.id);
+        assert_eq!(lattice.begin_nodes[5][0].borrow().id, node3.id);
         assert_eq!(
             lattice.eos_node().borrow().id,
-            lattice.begin_nodes[4][0].borrow().id
+            lattice.begin_nodes[8][0].borrow().id
         );
 
         assert_eq!(
@@ -571,10 +565,10 @@ mod tests {
         assert_eq!(node0.id, lattice.end_nodes[1][0].borrow().id);
         assert_eq!(node1.id, lattice.end_nodes[2][0].borrow().id);
         assert_eq!(node4.id, lattice.end_nodes[2][1].borrow().id);
-        assert_eq!(node2.id, lattice.end_nodes[3][0].borrow().id);
-        assert_eq!(node5.id, lattice.end_nodes[3][1].borrow().id);
-        assert_eq!(node3.id, lattice.end_nodes[4][0].borrow().id);
-        assert_eq!(node6.id, lattice.end_nodes[4][1].borrow().id);
+        assert_eq!(node2.id, lattice.end_nodes[5][0].borrow().id);
+        assert_eq!(node5.id, lattice.end_nodes[5][1].borrow().id);
+        assert_eq!(node3.id, lattice.end_nodes[8][0].borrow().id);
+        assert_eq!(node6.id, lattice.end_nodes[8][1].borrow().id);
     }
 
     #[test]
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index 3ee3958b..b9af7582 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -24,6 +24,7 @@ pub struct Unigram {
     pub(super) eos_id: usize,
     fuse_unk: bool,
+    is_optimized: bool,
 }
 
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
@@ -46,6 +47,7 @@ impl Clone for Unigram {
             bos_id: self.bos_id,
             eos_id: self.eos_id,
             fuse_unk: self.fuse_unk,
+            is_optimized: self.is_optimized,
         }
     }
 }
@@ -122,6 +124,7 @@
         }
         let trie = builder.build();
         let fuse_unk = true;
+        let is_optimized = true;
 
         Ok(Unigram {
             vocab,
@@ -133,6 +136,7 @@
             unk_id,
             fuse_unk,
             cache: Cache::default(),
+            is_optimized,
         })
     }
 
@@ -142,6 +146,11 @@
         self.cache = self.cache.fresh();
     }
 
+    #[cfg(test)]
+    pub(super) fn set_optimized(&mut self, is_optimized: bool) {
+        self.is_optimized = is_optimized;
+    }
+
     pub(super) fn len(&self) -> usize {
         self.vocab.len()
     }
@@ -179,7 +188,7 @@
             }
 
             if !has_single_node {
-                lattice.insert(begin_pos, 1, unk_score, self.unk_id);
+                lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
             }
             begin_pos += mblen
         }
@@ -212,7 +221,11 @@
         if let Some(result) = self.cache.get(sentence) {
             result.to_vec()
         } else {
-            let result = self.encode_optimized(sentence);
+            let result = if self.is_optimized {
+                self.encode_optimized(sentence)
+            } else {
+                self.encode_unoptimized(sentence)
+            };
             self.cache.set(sentence.to_owned(), result.clone());
             result
         }
@@ -327,7 +340,6 @@
         results
     }
 
-    #[allow(dead_code)]
     fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
         let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
@@ -541,27 +553,34 @@ mod tests {
         ];
 
         let mut model = Unigram::from(sentencepieces, 0).unwrap();
 
-        assert_eq!(model.encode("abc"), vec!["abc"]);
-        assert_eq!(model.encode("AB"), vec!["AB"]);
-        model.set_fuse_unk(false);
-        assert_eq!(model.encode("AB"), vec!["A", "B"]);
-        model.set_fuse_unk(true);
+        for is_optimized in &[true, false] {
+            model.set_optimized(*is_optimized);
+            println!("IsOptimized {:?}", is_optimized);
+            assert_eq!(model.encode("abc"), vec!["abc"]);
+            assert_eq!(model.encode("AB"), vec!["AB"]);
 
-        assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
-        assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
-        assert_eq!(
-            model.encode("xabcabaabcdd"),
-            vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
-        );
-        model.set_fuse_unk(false);
-        assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
-        model.set_fuse_unk(true);
+            model.set_fuse_unk(false);
+            assert_eq!(model.encode("AB"), vec!["A", "B"]);
+            model.set_fuse_unk(true);
+            assert_eq!(model.encode("AB"), vec!["AB"]);
 
-        // User encoded in original version
-        assert_eq!(model.encode("ABC"), vec!["ABC"]);
-        assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]);
-        assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
-        assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
"cd"]); + assert_eq!(model.encode("abcd"), vec!["ab", "cd"]); + assert_eq!(model.encode("abcc"), vec!["abc", "c"]); + assert_eq!( + model.encode("xabcabaabcdd"), + vec!["x", "abc", "ab", "a", "ab", "cd", "d"] + ); + model.set_fuse_unk(false); + assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]); + model.set_fuse_unk(true); + assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]); + + // User encoded in original version + assert_eq!(model.encode("ABC"), vec!["ABC"]); + assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]); + assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]); + assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]); + } } } diff --git a/tokenizers/tests/unigram.rs b/tokenizers/tests/unigram.rs index be0517ba..be81842c 100644 --- a/tokenizers/tests/unigram.rs +++ b/tokenizers/tests/unigram.rs @@ -1,14 +1,13 @@ #[cfg(not(debug_assertions))] use assert_approx_eq::assert_approx_eq; -#[cfg(not(debug_assertions))] use std::collections::HashMap; -#[cfg(not(debug_assertions))] use std::fs::read_to_string; use std::path::Path; #[cfg(not(debug_assertions))] use tokenizers::models::unigram::Lattice; use tokenizers::models::unigram::Unigram; -use tokenizers::tokenizer::Model; +use tokenizers::models::unigram::UnigramTrainer; +use tokenizers::tokenizer::{Model, Trainer}; #[test] fn test_unigram_from_file() { @@ -39,6 +38,26 @@ fn test_unigram_from_file() { ); } +#[test] +fn test_train_unigram_from_file() { + let content = read_to_string("data/small.txt").unwrap(); + let mut word_counts = HashMap::new(); + content.split_whitespace().for_each(|word| { + // This is important for the test of char vs u8 + let word = format!("▁{}", word.to_string()); + *word_counts.entry(word).or_insert(0) += 1; + }); + + // println!("Words counts {:?}", word_counts); + + let trainer = UnigramTrainer::builder() + .show_progress(false) + .build() + .unwrap(); + let (model, _) = trainer.train(word_counts).unwrap(); + assert_eq!(model.get_vocab_size(), 719); +} + #[cfg(not(debug_assertions))] #[test] fn test_sample() {