Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-23 16:49:27 +00:00
Fixing Trainer with u8 instead of chars. (#452)
* Fixing Trainer with u8 instead of chars. Now checks both the optimized and unoptimized encoding schemes for Unigram.
* Small fixes.
* Fixing makefile.
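In short, Lattice node positions and lengths switch from char offsets to byte (u8) offsets, so they compose directly with Rust's byte-indexed string slicing on multi-byte UTF-8 text. A minimal standalone sketch of the distinction the diff revolves around (values mirror the updated "テストab" test below):

    fn main() {
        let s = "テストab";
        // Char-oriented count: 5 scalar values (3 katakana + 2 ASCII).
        assert_eq!(s.chars().count(), 5);
        // Byte-oriented length: each katakana is 3 bytes in UTF-8, so 3 * 3 + 2 = 11.
        assert_eq!(s.len(), 11);
        // &str slicing is byte-indexed, so byte offsets are O(1) to use:
        assert_eq!(&s[0..3], "テ");
        assert_eq!(&s[9..11], "ab");
    }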
Python test (class TestUnigram):

@@ -7,10 +7,9 @@ from ..utils import data_dir, train_files


 class TestUnigram:
-    @pytest.mark.slow
     def test_train(self, train_files):
         tokenizer = SentencePieceUnigramTokenizer()
-        tokenizer.train(train_files["big"], show_progress=False)
+        tokenizer.train(train_files["small"], show_progress=False)

         filename = "tests/data/unigram_trained.json"
         tokenizer.save(filename)
Python test fixtures (train_files):

@@ -61,8 +61,17 @@ def openai_files(data_dir)

 @pytest.fixture(scope="session")
 def train_files(data_dir):
+    big = download("https://norvig.com/big.txt")
+    small = os.path.join(DATA_PATH, "small.txt")
+    with open(small, "w") as f:
+        with open(big, "r") as g:
+            for i, line in enumerate(g):
+                f.write(line)
+                if i > 100:
+                    break
     return {
-        "big": download("https://norvig.com/big.txt"),
+        "small": small,
+        "big": big,
     }


Makefile:

@@ -4,8 +4,8 @@ TESTS_DIR = tests

 dir_guard=@mkdir -p $(@D)

-SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt
-BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
+BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
 TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt

 .PHONY : build
Rust, Unigram lattice:

@@ -137,7 +137,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {

 impl<'a> Lattice<'a> {
     pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
-        let len = sentence.chars().count();
+        let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
         let mut nodes: Vec<NodeRef> = Vec::with_capacity(k_reserved_node_size);
@@ -177,7 +177,8 @@ impl<'a> Lattice<'a> {

     pub fn viterbi(&mut self) -> Vec<NodeRef> {
         let len = self.len;
-        for pos in 0..=len {
+        let mut pos = 0;
+        while pos <= len {
             if self.begin_nodes[pos].is_empty() {
                 return vec![];
             }
@@ -201,6 +202,11 @@ impl<'a> Lattice<'a> {
                     None => return vec![],
                 }
             }
+            if let Some(c) = self.sentence[pos..].chars().next() {
+                pos += c.len_utf8();
+            } else {
+                break;
+            }
         }

         let mut results: Vec<NodeRef> = vec![];
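Stepping a byte index with pos += 1 could land inside a multi-byte codepoint, so the loop above advances by the UTF-8 width of the character starting at pos. A minimal standalone sketch of that stride (not part of the diff; the helper name is illustrative):

    fn char_boundaries(s: &str) -> Vec<usize> {
        // Collect every byte offset that starts a codepoint, plus the end offset.
        let mut pos = 0;
        let mut bounds = vec![];
        while pos <= s.len() {
            bounds.push(pos);
            match s[pos..].chars().next() {
                Some(c) => pos += c.len_utf8(), // same stride as the new viterbi loop
                None => break,
            }
        }
        bounds
    }

    // char_boundaries("テab") == vec![0, 3, 4, 5]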
@@ -220,11 +226,7 @@ impl<'a> Lattice<'a> {
     }

     pub fn piece(&self, node: &Node) -> String {
-        self.sentence
-            .chars()
-            .skip(node.pos)
-            .take(node.length)
-            .collect()
+        self.sentence[node.pos..node.pos + node.length].to_owned()
     }

     pub fn tokens(&mut self) -> Vec<String> {
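The old body walked the string char by char, which is O(n) in the prefix length; with node.pos and node.length now expressed in bytes, piece can slice in O(1). The trade-off is that &str[a..b] panics when a bound falls inside a codepoint, so every position stored in the lattice must be a valid byte boundary. A sketch of the contrast (illustrative helpers, not the crate's API):

    fn piece_by_chars(s: &str, pos: usize, len: usize) -> String {
        // Char-offset version: scans from the start of the string.
        s.chars().skip(pos).take(len).collect()
    }

    fn piece_by_bytes(s: &str, pos: usize, len: usize) -> String {
        // Byte-offset version: constant time, but pos and pos + len
        // must both sit on UTF-8 char boundaries or this panics.
        s[pos..pos + len].to_owned()
    }

    // piece_by_chars("テストab", 0, 1) == "テ"
    // piece_by_bytes("テストab", 0, 3) == "テ"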
@@ -306,14 +308,11 @@ impl<'a> Lattice<'a> {
     pub fn len(&self) -> usize {
         self.len
     }

     pub fn is_empty(&self) -> bool {
         self.len == 0
     }

-    pub fn utf8_len(&self) -> usize {
-        self.sentence.len()
-    }

     pub fn bos_node(&self) -> NodeRef {
         Rc::clone(&self.end_nodes[0][0])
     }
@@ -443,18 +442,14 @@ mod tests {
         let lattice = Lattice::from("", 0, 1, 2);

         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
-        // EXPECT_EQ(0, lattice.utf8_size());

         let lattice = Lattice::from("", 0, 1, 2);
         assert_eq!(lattice.len(), 0);
-        assert_eq!(lattice.utf8_len(), 0);
         assert_eq!(lattice.sentence(), "");
         assert_eq!(lattice.surface(0), "");

         let lattice = Lattice::from("test", 0, 1, 2);
         assert_eq!(lattice.len(), 4);
-        assert_eq!(lattice.utf8_len(), 4);
         assert_eq!(lattice.sentence(), "test");
         assert_eq!(lattice.surface(0), "test");
         assert_eq!(lattice.surface(1), "est");
@@ -476,8 +471,7 @@ mod tests {
         );

         let lattice = Lattice::from("テストab", 0, 1, 2);
-        assert_eq!(lattice.len(), 5);
-        assert_eq!(lattice.utf8_len(), 11);
+        assert_eq!(lattice.len(), 11);
         assert_eq!(lattice.sentence(), "テストab");
         assert_eq!(lattice.surface(0), "テストab");
         assert_eq!(lattice.surface(1), "ストab");
@@ -492,11 +486,11 @@ mod tests {

         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
-        lattice.insert(2, 1, 0.0, 5);
-        lattice.insert(3, 1, 0.0, 6);
+        lattice.insert(2, 3, 0.0, 5);
+        lattice.insert(5, 3, 0.0, 6);
         lattice.insert(0, 2, 0.0, 7);
-        lattice.insert(1, 2, 0.0, 8);
-        lattice.insert(2, 2, 0.0, 9);
+        lattice.insert(1, 4, 0.0, 8);
+        lattice.insert(2, 6, 0.0, 9);
         // 0 & 1 are bos and eos
         let node0 = lattice.nodes[2].borrow();
         let node1 = lattice.nodes[3].borrow();
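The updated positions and lengths line up if the sentence under test is "ABあい", the string used in the matching upstream SentencePiece lattice test; the string itself is not visible in this hunk, so treat that as an assumption. Worked byte layout under that assumption:

    // "ABあい": 'A' and 'B' are 1 byte each, 'あ' and 'い' are 3 bytes each.
    //   'A'    -> insert(0, 1)   (unchanged)
    //   'B'    -> insert(1, 1)   (unchanged)
    //   'あ'   -> insert(2, 3)   // was insert(2, 1) in chars
    //   'い'   -> insert(5, 3)   // was insert(3, 1) in chars
    //   "AB"   -> insert(0, 2)   (unchanged)
    //   "Bあ"  -> insert(1, 4)   // was insert(1, 2) in chars
    //   "あい" -> insert(2, 6)   // was insert(2, 2) in chars
    // The end-of-sentence boundary moves from char index 4 to byte index 8,
    // which is why indices 5 and 8 replace 3 and 4 in the assertions below.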
@@ -517,18 +511,18 @@ mod tests {
         assert_eq!(node0.pos, 0);
         assert_eq!(node1.pos, 1);
         assert_eq!(node2.pos, 2);
-        assert_eq!(node3.pos, 3);
+        assert_eq!(node3.pos, 5);
         assert_eq!(node4.pos, 0);
         assert_eq!(node5.pos, 1);
         assert_eq!(node6.pos, 2);

         assert_eq!(node0.length, 1);
         assert_eq!(node1.length, 1);
-        assert_eq!(node2.length, 1);
-        assert_eq!(node3.length, 1);
+        assert_eq!(node2.length, 3);
+        assert_eq!(node3.length, 3);
         assert_eq!(node4.length, 2);
-        assert_eq!(node5.length, 2);
-        assert_eq!(node6.length, 2);
+        assert_eq!(node5.length, 4);
+        assert_eq!(node6.length, 6);

         assert_eq!(lattice.bos_node().borrow().id, 1);
         assert_eq!(lattice.eos_node().borrow().id, 2);
@@ -543,14 +537,14 @@ mod tests {
         assert_eq!(lattice.begin_nodes[0].len(), 2);
         assert_eq!(lattice.begin_nodes[1].len(), 2);
         assert_eq!(lattice.begin_nodes[2].len(), 2);
-        assert_eq!(lattice.begin_nodes[3].len(), 1);
-        assert_eq!(lattice.begin_nodes[4].len(), 1);
+        assert_eq!(lattice.begin_nodes[5].len(), 1);
+        assert_eq!(lattice.begin_nodes[8].len(), 1);

         assert_eq!(lattice.end_nodes[0].len(), 1);
         assert_eq!(lattice.end_nodes[1].len(), 1);
         assert_eq!(lattice.end_nodes[2].len(), 2);
-        assert_eq!(lattice.end_nodes[3].len(), 2);
-        assert_eq!(lattice.end_nodes[4].len(), 2);
+        assert_eq!(lattice.end_nodes[5].len(), 2);
+        assert_eq!(lattice.end_nodes[8].len(), 2);

         assert_eq!(lattice.begin_nodes[0][0].borrow().id, node0.id);
         assert_eq!(lattice.begin_nodes[0][1].borrow().id, node4.id);
@@ -558,10 +552,10 @@ mod tests {
         assert_eq!(lattice.begin_nodes[1][1].borrow().id, node5.id);
         assert_eq!(lattice.begin_nodes[2][0].borrow().id, node2.id);
         assert_eq!(lattice.begin_nodes[2][1].borrow().id, node6.id);
-        assert_eq!(lattice.begin_nodes[3][0].borrow().id, node3.id);
+        assert_eq!(lattice.begin_nodes[5][0].borrow().id, node3.id);
         assert_eq!(
             lattice.eos_node().borrow().id,
-            lattice.begin_nodes[4][0].borrow().id
+            lattice.begin_nodes[8][0].borrow().id
         );

         assert_eq!(
@@ -571,10 +565,10 @@ mod tests {
         assert_eq!(node0.id, lattice.end_nodes[1][0].borrow().id);
         assert_eq!(node1.id, lattice.end_nodes[2][0].borrow().id);
         assert_eq!(node4.id, lattice.end_nodes[2][1].borrow().id);
-        assert_eq!(node2.id, lattice.end_nodes[3][0].borrow().id);
-        assert_eq!(node5.id, lattice.end_nodes[3][1].borrow().id);
-        assert_eq!(node3.id, lattice.end_nodes[4][0].borrow().id);
-        assert_eq!(node6.id, lattice.end_nodes[4][1].borrow().id);
+        assert_eq!(node2.id, lattice.end_nodes[5][0].borrow().id);
+        assert_eq!(node5.id, lattice.end_nodes[5][1].borrow().id);
+        assert_eq!(node3.id, lattice.end_nodes[8][0].borrow().id);
+        assert_eq!(node6.id, lattice.end_nodes[8][1].borrow().id);
     }

     #[test]
Rust, Unigram model:

@@ -24,6 +24,7 @@ pub struct Unigram {
     pub(super) eos_id: usize,

     fuse_unk: bool,
+    is_optimized: bool,
 }
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
@@ -46,6 +47,7 @@ impl Clone for Unigram {
             bos_id: self.bos_id,
             eos_id: self.eos_id,
             fuse_unk: self.fuse_unk,
+            is_optimized: self.is_optimized,
         }
     }
 }
@@ -122,6 +124,7 @@ impl Unigram {
         }
         let trie = builder.build();
         let fuse_unk = true;
+        let is_optimized = true;

         Ok(Unigram {
             vocab,
@@ -133,6 +136,7 @@ impl Unigram {
             unk_id,
             fuse_unk,
             cache: Cache::default(),
+            is_optimized,
         })
     }

@@ -142,6 +146,11 @@ impl Unigram {
         self.cache = self.cache.fresh();
     }

+    #[cfg(test)]
+    pub(super) fn set_optimized(&mut self, is_optimized: bool) {
+        self.is_optimized = is_optimized;
+    }
+
     pub(super) fn len(&self) -> usize {
         self.vocab.len()
     }
@@ -179,7 +188,7 @@ impl Unigram {
             }

             if !has_single_node {
-                lattice.insert(begin_pos, 1, unk_score, self.unk_id);
+                lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
             }
             begin_pos += mblen
         }
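With byte offsets, the unknown-token fallback has to cover the full UTF-8 width of the character at begin_pos (mblen), not a single byte, or the node would end mid-codepoint. A quick illustration of the value involved (standalone sketch, not the crate's code):

    // A multi-byte character makes the difference visible:
    let c = '東';
    let mblen = c.len_utf8(); // 3 bytes in UTF-8
    assert_eq!(mblen, 3);
    // lattice.insert(begin_pos, mblen, unk_score, unk_id) followed by
    // begin_pos += mblen keeps every node aligned to char boundaries.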
@@ -212,7 +221,11 @@ impl Unigram {
         if let Some(result) = self.cache.get(sentence) {
             result.to_vec()
         } else {
-            let result = self.encode_optimized(sentence);
+            let result = if self.is_optimized {
+                self.encode_optimized(sentence)
+            } else {
+                self.encode_unoptimized(sentence)
+            };
             self.cache.set(sentence.to_owned(), result.clone());
             result
         }
@@ -327,7 +340,6 @@ impl Unigram {
         results
     }

-    #[allow(dead_code)]
     fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
         let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
@@ -541,27 +553,34 @@ mod tests {
         ];

         let mut model = Unigram::from(sentencepieces, 0).unwrap();
-        assert_eq!(model.encode("abc"), vec!["abc"]);
-        assert_eq!(model.encode("AB"), vec!["AB"]);

-        model.set_fuse_unk(false);
-        assert_eq!(model.encode("AB"), vec!["A", "B"]);
-        model.set_fuse_unk(true);
+        for is_optimized in &[true, false] {
+            model.set_optimized(*is_optimized);
+            println!("IsOptimized {:?}", is_optimized);
+            assert_eq!(model.encode("abc"), vec!["abc"]);
+            assert_eq!(model.encode("AB"), vec!["AB"]);

-        assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
-        assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
-        assert_eq!(
-            model.encode("xabcabaabcdd"),
-            vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
-        );
-        model.set_fuse_unk(false);
-        assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
-        model.set_fuse_unk(true);
+            model.set_fuse_unk(false);
+            assert_eq!(model.encode("AB"), vec!["A", "B"]);
+            model.set_fuse_unk(true);
+            assert_eq!(model.encode("AB"), vec!["AB"]);

-        // User encoded in original version
-        assert_eq!(model.encode("ABC"), vec!["ABC"]);
-        assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]);
-        assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
-        assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
+            assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
+            assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
+            assert_eq!(
+                model.encode("xabcabaabcdd"),
+                vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
+            );
+            model.set_fuse_unk(false);
+            assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
+            model.set_fuse_unk(true);
+            assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]);
+
+            // User encoded in original version
+            assert_eq!(model.encode("ABC"), vec!["ABC"]);
+            assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]);
+            assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
+            assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
+        }
     }
 }
Rust, Unigram integration tests:

@@ -1,14 +1,13 @@
 #[cfg(not(debug_assertions))]
 use assert_approx_eq::assert_approx_eq;
-#[cfg(not(debug_assertions))]
 use std::collections::HashMap;
-#[cfg(not(debug_assertions))]
 use std::fs::read_to_string;
 use std::path::Path;
 #[cfg(not(debug_assertions))]
 use tokenizers::models::unigram::Lattice;
 use tokenizers::models::unigram::Unigram;
-use tokenizers::tokenizer::Model;
+use tokenizers::models::unigram::UnigramTrainer;
+use tokenizers::tokenizer::{Model, Trainer};

 #[test]
 fn test_unigram_from_file() {
@@ -39,6 +38,26 @@ fn test_unigram_from_file() {
     );
 }

+#[test]
+fn test_train_unigram_from_file() {
+    let content = read_to_string("data/small.txt").unwrap();
+    let mut word_counts = HashMap::new();
+    content.split_whitespace().for_each(|word| {
+        // This is important for the test of char vs u8
+        let word = format!("▁{}", word.to_string());
+        *word_counts.entry(word).or_insert(0) += 1;
+    });
+
+    // println!("Words counts {:?}", word_counts);
+
+    let trainer = UnigramTrainer::builder()
+        .show_progress(false)
+        .build()
+        .unwrap();
+    let (model, _) = trainer.train(word_counts).unwrap();
+    assert_eq!(model.get_vocab_size(), 719);
+}
+
 #[cfg(not(debug_assertions))]
 #[test]
 fn test_sample() {
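The "▁" prefix added to every word above is SentencePiece's whitespace marker, U+2581 (LOWER ONE EIGHTH BLOCK), which is itself a 3-byte codepoint in UTF-8. Prefixing it means even a pure-ASCII corpus exercises multi-byte offsets, which is exactly the char-vs-u8 behavior this test guards. A quick check of that property (standalone sketch):

    // '▁' (U+2581) occupies 3 bytes but is a single char:
    assert_eq!("▁".len(), 3);
    assert_eq!("▁".chars().count(), 1);
    // So "▁word" has byte length word.len() + 3, not word.len() + 1.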