Fixing Trainer with u8 instead of chars. (#452)

* Fixing Trainer with u8 instead of chars.

Now check both optimized and unoptimized encoding schemes for Unigram.

* Small fixes.

* Fixing makefile.
Nicolas Patry
2020-10-09 18:57:14 +02:00
committed by GitHub
parent 35feff0042
commit fbca797b3d
6 changed files with 106 additions and 66 deletions
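
Editor's note: the core of this patch is that `Lattice` positions switch from Unicode character indices (`chars`) to UTF-8 byte offsets (`u8`). A minimal standalone sketch (plain Rust, independent of the tokenizers codebase) of why the two measurements diverge, and how a byte position is stepped one character at a time as the patched `viterbi` loop below does; the string is the one used in the updated lattice tests:

```rust
fn main() {
    let s = "テストab";
    assert_eq!(s.chars().count(), 5); // what `len` measured before the patch
    assert_eq!(s.len(), 11);          // what it measures after: UTF-8 bytes

    // Advancing a byte position character by character: each step is
    // 1 to 4 bytes wide, never a fixed 1.
    let mut pos = 0;
    while pos < s.len() {
        let c = s[pos..].chars().next().unwrap();
        pos += c.len_utf8();
    }
    assert_eq!(pos, 11); // lands exactly on the byte length, never mid-char
}
```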

View File

@@ -7,10 +7,9 @@ from ..utils import data_dir, train_files
 class TestUnigram:
     @pytest.mark.slow
     def test_train(self, train_files):
         tokenizer = SentencePieceUnigramTokenizer()
-        tokenizer.train(train_files["big"], show_progress=False)
+        tokenizer.train(train_files["small"], show_progress=False)
         filename = "tests/data/unigram_trained.json"
         tokenizer.save(filename)

View File

@@ -61,8 +61,17 @@ def openai_files(data_dir):
 @pytest.fixture(scope="session")
 def train_files(data_dir):
+    big = download("https://norvig.com/big.txt")
+    small = os.path.join(DATA_PATH, "small.txt")
+    with open(small, "w") as f:
+        with open(big, "r") as g:
+            for i, line in enumerate(g):
+                f.write(line)
+                if i > 100:
+                    break
     return {
-        "big": download("https://norvig.com/big.txt"),
+        "small": small,
+        "big": big,
     }

View File

@@ -4,8 +4,8 @@ TESTS_DIR = tests
 dir_guard=@mkdir -p $(@D)
 
-SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt
-BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
 TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt
 
 .PHONY : build

View File

@@ -137,7 +137,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
 impl<'a> Lattice<'a> {
     pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
-        let len = sentence.chars().count();
+        let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
         let mut nodes: Vec<NodeRef> = Vec::with_capacity(k_reserved_node_size);
@@ -177,7 +178,8 @@ impl<'a> Lattice<'a> {
     pub fn viterbi(&mut self) -> Vec<NodeRef> {
         let len = self.len;
-        for pos in 0..=len {
+        let mut pos = 0;
+        while pos <= len {
             if self.begin_nodes[pos].is_empty() {
                 return vec![];
             }
@@ -201,6 +202,11 @@ impl<'a> Lattice<'a> {
                     None => return vec![],
                 }
             }
+            if let Some(c) = self.sentence[pos..].chars().next() {
+                pos += c.len_utf8();
+            } else {
+                break;
+            }
         }
 
         let mut results: Vec<NodeRef> = vec![];
@@ -220,11 +226,7 @@ impl<'a> Lattice<'a> {
     }
 
     pub fn piece(&self, node: &Node) -> String {
-        self.sentence
-            .chars()
-            .skip(node.pos)
-            .take(node.length)
-            .collect()
+        self.sentence[node.pos..node.pos + node.length].to_owned()
     }
 
     pub fn tokens(&mut self) -> Vec<String> {
@@ -306,14 +308,11 @@ impl<'a> Lattice<'a> {
     pub fn len(&self) -> usize {
         self.len
     }
 
     pub fn is_empty(&self) -> bool {
         self.len == 0
     }
 
-    pub fn utf8_len(&self) -> usize {
-        self.sentence.len()
-    }
-
     pub fn bos_node(&self) -> NodeRef {
         Rc::clone(&self.end_nodes[0][0])
     }
@@ -443,18 +442,14 @@ mod tests {
         let lattice = Lattice::from("", 0, 1, 2);
         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
-        // EXPECT_EQ(0, lattice.utf8_size());
 
         let lattice = Lattice::from("", 0, 1, 2);
         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
         assert_eq!(lattice.sentence(), "");
         assert_eq!(lattice.surface(0), "");
 
         let lattice = Lattice::from("test", 0, 1, 2);
         assert_eq!(lattice.len(), 4);
-        assert_eq!(lattice.utf8_len(), 4);
         assert_eq!(lattice.sentence(), "test");
         assert_eq!(lattice.surface(0), "test");
         assert_eq!(lattice.surface(1), "est");
@ -476,8 +471,7 @@ mod tests {
);
let lattice = Lattice::from("テストab", 0, 1, 2);
assert_eq!(lattice.len(), 5);
assert_eq!(lattice.utf8_len(), 11);
assert_eq!(lattice.len(), 11);
assert_eq!(lattice.sentence(), "テストab");
assert_eq!(lattice.surface(0), "テストab");
assert_eq!(lattice.surface(1), "ストab");
@@ -492,11 +486,11 @@ mod tests {
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
-        lattice.insert(2, 1, 0.0, 5);
-        lattice.insert(3, 1, 0.0, 6);
+        lattice.insert(2, 3, 0.0, 5);
+        lattice.insert(5, 3, 0.0, 6);
         lattice.insert(0, 2, 0.0, 7);
-        lattice.insert(1, 2, 0.0, 8);
-        lattice.insert(2, 2, 0.0, 9);
+        lattice.insert(1, 4, 0.0, 8);
+        lattice.insert(2, 6, 0.0, 9);
         // 0 & 1 are bos and eos
         let node0 = lattice.nodes[2].borrow();
         let node1 = lattice.nodes[3].borrow();
@@ -517,18 +511,18 @@ mod tests {
         assert_eq!(node0.pos, 0);
         assert_eq!(node1.pos, 1);
         assert_eq!(node2.pos, 2);
-        assert_eq!(node3.pos, 3);
+        assert_eq!(node3.pos, 5);
         assert_eq!(node4.pos, 0);
         assert_eq!(node5.pos, 1);
         assert_eq!(node6.pos, 2);
 
         assert_eq!(node0.length, 1);
         assert_eq!(node1.length, 1);
-        assert_eq!(node2.length, 1);
-        assert_eq!(node3.length, 1);
+        assert_eq!(node2.length, 3);
+        assert_eq!(node3.length, 3);
         assert_eq!(node4.length, 2);
-        assert_eq!(node5.length, 2);
-        assert_eq!(node6.length, 2);
+        assert_eq!(node5.length, 4);
+        assert_eq!(node6.length, 6);
 
         assert_eq!(lattice.bos_node().borrow().id, 1);
         assert_eq!(lattice.eos_node().borrow().id, 2);
@@ -543,14 +537,14 @@ mod tests {
         assert_eq!(lattice.begin_nodes[0].len(), 2);
         assert_eq!(lattice.begin_nodes[1].len(), 2);
         assert_eq!(lattice.begin_nodes[2].len(), 2);
-        assert_eq!(lattice.begin_nodes[3].len(), 1);
-        assert_eq!(lattice.begin_nodes[4].len(), 1);
+        assert_eq!(lattice.begin_nodes[5].len(), 1);
+        assert_eq!(lattice.begin_nodes[8].len(), 1);
 
         assert_eq!(lattice.end_nodes[0].len(), 1);
         assert_eq!(lattice.end_nodes[1].len(), 1);
         assert_eq!(lattice.end_nodes[2].len(), 2);
-        assert_eq!(lattice.end_nodes[3].len(), 2);
-        assert_eq!(lattice.end_nodes[4].len(), 2);
+        assert_eq!(lattice.end_nodes[5].len(), 2);
+        assert_eq!(lattice.end_nodes[8].len(), 2);
 
         assert_eq!(lattice.begin_nodes[0][0].borrow().id, node0.id);
         assert_eq!(lattice.begin_nodes[0][1].borrow().id, node4.id);
@@ -558,10 +552,10 @@ mod tests {
         assert_eq!(lattice.begin_nodes[1][1].borrow().id, node5.id);
         assert_eq!(lattice.begin_nodes[2][0].borrow().id, node2.id);
         assert_eq!(lattice.begin_nodes[2][1].borrow().id, node6.id);
-        assert_eq!(lattice.begin_nodes[3][0].borrow().id, node3.id);
+        assert_eq!(lattice.begin_nodes[5][0].borrow().id, node3.id);
         assert_eq!(
             lattice.eos_node().borrow().id,
-            lattice.begin_nodes[4][0].borrow().id
+            lattice.begin_nodes[8][0].borrow().id
         );
         assert_eq!(
@@ -571,10 +565,10 @@ mod tests {
         assert_eq!(node0.id, lattice.end_nodes[1][0].borrow().id);
         assert_eq!(node1.id, lattice.end_nodes[2][0].borrow().id);
         assert_eq!(node4.id, lattice.end_nodes[2][1].borrow().id);
-        assert_eq!(node2.id, lattice.end_nodes[3][0].borrow().id);
-        assert_eq!(node5.id, lattice.end_nodes[3][1].borrow().id);
-        assert_eq!(node3.id, lattice.end_nodes[4][0].borrow().id);
-        assert_eq!(node6.id, lattice.end_nodes[4][1].borrow().id);
+        assert_eq!(node2.id, lattice.end_nodes[5][0].borrow().id);
+        assert_eq!(node5.id, lattice.end_nodes[5][1].borrow().id);
+        assert_eq!(node3.id, lattice.end_nodes[8][0].borrow().id);
+        assert_eq!(node6.id, lattice.end_nodes[8][1].borrow().id);
     }
 
     #[test]
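
Editor's note on the rewritten constants above: the new `insert` lengths and `begin_nodes`/`end_nodes` indices are plain UTF-8 byte offsets. They are all consistent with a test sentence shaped like "ABあい" (two 1-byte ASCII characters followed by two 3-byte kana); the concrete sentence is an assumption here, not visible in the hunks. The arithmetic:

```rust
fn main() {
    let s = "ABあい"; // hypothetical, but it matches every updated assertion
    assert_eq!(s.len(), 8);        // eos moves from begin_nodes[4] to begin_nodes[8]
    assert_eq!(&s[0..1], "A");     // insert(0, 1, ..) is unchanged
    assert_eq!(&s[1..2], "B");     // insert(1, 1, ..) is unchanged
    assert_eq!(&s[2..5], "あ");    // insert(2, 1, ..) becomes insert(2, 3, ..)
    assert_eq!(&s[5..8], "い");    // insert(3, 1, ..) becomes insert(5, 3, ..)
    assert_eq!(&s[1..5], "Bあ");   // insert(1, 2, ..) becomes insert(1, 4, ..)
    assert_eq!(&s[2..8], "あい");  // insert(2, 2, ..) becomes insert(2, 6, ..)
}
```

Byte offsets are also what make the new `piece` body possible: `&sentence[pos..pos + length]` is a constant-time slice, where the removed `chars().skip(pos).take(length)` walk was linear in the length of the prefix.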

View File

@@ -24,6 +24,7 @@ pub struct Unigram {
     pub(super) eos_id: usize,
     fuse_unk: bool,
+    is_optimized: bool,
 }
 
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
@@ -46,6 +47,7 @@ impl Clone for Unigram {
             bos_id: self.bos_id,
             eos_id: self.eos_id,
             fuse_unk: self.fuse_unk,
+            is_optimized: self.is_optimized,
         }
     }
 }
@@ -122,6 +124,7 @@ impl Unigram {
         }
         let trie = builder.build();
         let fuse_unk = true;
+        let is_optimized = true;
 
         Ok(Unigram {
             vocab,
@@ -133,6 +136,7 @@ impl Unigram {
             unk_id,
             fuse_unk,
             cache: Cache::default(),
+            is_optimized,
         })
     }
@@ -142,6 +146,11 @@ impl Unigram {
         self.cache = self.cache.fresh();
     }
 
+    #[cfg(test)]
+    pub(super) fn set_optimized(&mut self, is_optimized: bool) {
+        self.is_optimized = is_optimized;
+    }
+
     pub(super) fn len(&self) -> usize {
         self.vocab.len()
     }
@@ -179,7 +188,7 @@ impl Unigram {
             }
             if !has_single_node {
-                lattice.insert(begin_pos, 1, unk_score, self.unk_id);
+                lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
             }
             begin_pos += mblen
         }
@@ -212,7 +221,11 @@ impl Unigram {
         if let Some(result) = self.cache.get(sentence) {
             result.to_vec()
         } else {
-            let result = self.encode_optimized(sentence);
+            let result = if self.is_optimized {
+                self.encode_optimized(sentence)
+            } else {
+                self.encode_unoptimized(sentence)
+            };
             self.cache.set(sentence.to_owned(), result.clone());
             result
         }
@@ -327,7 +340,6 @@ impl Unigram {
         results
     }
 
-    #[allow(dead_code)]
     fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
         let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
@@ -541,12 +553,17 @@ mod tests {
         ];
 
         let mut model = Unigram::from(sentencepieces, 0).unwrap();
+        for is_optimized in &[true, false] {
+            model.set_optimized(*is_optimized);
+            println!("IsOptimized {:?}", is_optimized);
             assert_eq!(model.encode("abc"), vec!["abc"]);
             assert_eq!(model.encode("AB"), vec!["AB"]);
 
             model.set_fuse_unk(false);
             assert_eq!(model.encode("AB"), vec!["A", "B"]);
             model.set_fuse_unk(true);
+            assert_eq!(model.encode("AB"), vec!["AB"]);
 
             assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
             assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
@@ -557,6 +574,7 @@ mod tests {
             model.set_fuse_unk(false);
             assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
             model.set_fuse_unk(true);
+            assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]);
 
             // User encoded in original version
             assert_eq!(model.encode("ABC"), vec!["ABC"]);
@@ -564,4 +582,5 @@ mod tests {
             assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
             assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
+        }
     }
 }
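
A subtlety worth calling out in `populate_nodes` above: once positions are bytes, the unknown token has to be inserted with the full UTF-8 width `mblen` of the character it covers. Keeping the old fixed length of 1 would leave node boundaries in the middle of a multibyte character, and the byte-range slice in `piece` would then panic. A standalone sketch of that failure mode (not the library's code; `mblen` here is recomputed locally as the width of the first character):

```rust
fn main() {
    let sentence = "東京";
    // The UTF-8 width of the character starting at the current byte position.
    let mblen = sentence.chars().next().unwrap().len_utf8();
    assert_eq!(mblen, 3);

    // A width of `mblen` slices cleanly on a char boundary...
    assert_eq!(&sentence[0..mblen], "東");
    // ...while the old fixed width of 1 points mid-character:
    assert!(!sentence.is_char_boundary(1));
}
```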

View File

@@ -1,14 +1,13 @@
 #[cfg(not(debug_assertions))]
 use assert_approx_eq::assert_approx_eq;
-#[cfg(not(debug_assertions))]
 use std::collections::HashMap;
-#[cfg(not(debug_assertions))]
 use std::fs::read_to_string;
 use std::path::Path;
 #[cfg(not(debug_assertions))]
 use tokenizers::models::unigram::Lattice;
 use tokenizers::models::unigram::Unigram;
-use tokenizers::tokenizer::Model;
+use tokenizers::models::unigram::UnigramTrainer;
+use tokenizers::tokenizer::{Model, Trainer};
 
 #[test]
 fn test_unigram_from_file() {
@@ -39,6 +38,26 @@ fn test_unigram_from_file() {
     );
 }
 
+#[test]
+fn test_train_unigram_from_file() {
+    let content = read_to_string("data/small.txt").unwrap();
+    let mut word_counts = HashMap::new();
+    content.split_whitespace().for_each(|word| {
+        // This is important for the test of char vs u8
+        let word = format!("{}", word.to_string());
+        *word_counts.entry(word).or_insert(0) += 1;
+    });
+
+    // println!("Words counts {:?}", word_counts);
+    let trainer = UnigramTrainer::builder()
+        .show_progress(false)
+        .build()
+        .unwrap();
+
+    let (model, _) = trainer.train(word_counts).unwrap();
+    assert_eq!(model.get_vocab_size(), 719);
+}
+
 #[cfg(not(debug_assertions))]
 #[test]
 fn test_sample() {