Unigram - Add special_tokens at the end of training + optional unk

This commit is contained in:
Anthony MOI
2020-10-23 12:18:57 -04:00
committed by Anthony MOI
parent 390ef2f9f3
commit 73b5da917f
7 changed files with 256 additions and 115 deletions

View File

@@ -416,9 +416,8 @@ pub struct PyUnigram {}
impl PyUnigram {
#[new]
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
if vocab.is_some() && unk_id.is_none() || vocab.is_none() && unk_id.is_some() {}
match (vocab, unk_id) {
(Some(vocab), Some(unk_id)) => {
(Some(vocab), unk_id) => {
let model = Unigram::from(vocab, unk_id).map_err(|e| {
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
})?;

View File

@@ -42,3 +42,53 @@ class TestUnigram:
trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
bpe_tokenizer.train(trainer, [train_files["small"]])
def test_train_with_special_tokens(self):
filename = "tests/data/dummy-unigram-special_tokens-train.txt"
with open(filename, "w") as f:
f.write(
"""
[CLS] The Zen of Python, by Tim Peters [SEP]
[CLS] Beautiful is better than ugly. [SEP]
[CLS] Explicit is better than implicit. [SEP]
[CLS] Simple is better than complex. [SEP]
[CLS] Complex is better than complicated. [SEP]
[CLS] Flat is better than nested. [SEP]
[CLS] Sparse is better than dense. [SEP]
[CLS] Readability counts. [SEP]
[CLS] Special cases aren't special enough to break the rules. [SEP]
[CLS] Although practicality beats purity. [SEP]
[CLS] Errors should never pass silently. [SEP]
[CLS] Unless explicitly silenced. [SEP]
[CLS] In the face of ambiguity, refuse the temptation to guess. [SEP]
[CLS] There should be one-- and preferably only one --obvious way to do it. [SEP]
[CLS] Although that way may not be obvious at first unless you're Dutch. [SEP]
[CLS] Now is better than never. [SEP]
[CLS] Although never is often better than *right* now. [SEP]
[CLS] If the implementation is hard to explain, it's a bad idea. [SEP]
[CLS] If the implementation is easy to explain, it may be a good idea. [SEP]
[CLS] Namespaces are one honking great idea -- let's do more of those! [SEP]
"""
)
tokenizer = Tokenizer(models.Unigram())
trainer = trainers.UnigramTrainer(
show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
)
tokenizer.train(trainer, [filename])
assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
"[CLS]",
" T",
"h",
"i",
"s",
" is ",
"a",
" ",
"t",
"es",
"t ",
"[SEP]",
]

View File

@@ -58,7 +58,6 @@ pub struct Lattice<'a> {
pub(super) end_nodes: Vec<Vec<NodeRef>>,
bos_id: usize,
eos_id: usize,
unk_id: usize,
}
impl std::fmt::Display for Lattice<'_> {
@@ -136,7 +135,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
}
impl<'a> Lattice<'a> {
pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Lattice<'a> {
let len = sentence.bytes().count();
let k_reserved_node_size = 16;
// We are adding 2 tokens, bos and eos
@@ -161,7 +160,6 @@ impl<'a> Lattice<'a> {
end_nodes,
bos_id,
eos_id,
unk_id,
}
}
@@ -439,16 +437,16 @@ mod tests {
#[test]
fn set_sentence() {
let lattice = Lattice::from("", 0, 1, 2);
let lattice = Lattice::from("", 1, 2);
assert_eq!(lattice.len(), 0);
let lattice = Lattice::from("", 0, 1, 2);
let lattice = Lattice::from("", 1, 2);
assert_eq!(lattice.len(), 0);
assert_eq!(lattice.sentence(), "");
assert_eq!(lattice.surface(0), "");
let lattice = Lattice::from("test", 0, 1, 2);
let lattice = Lattice::from("test", 1, 2);
assert_eq!(lattice.len(), 4);
assert_eq!(lattice.sentence(), "test");
assert_eq!(lattice.surface(0), "test");
@@ -470,7 +468,7 @@ mod tests {
eos.borrow().id
);
let lattice = Lattice::from("テストab", 0, 1, 2);
let lattice = Lattice::from("テストab", 1, 2);
assert_eq!(lattice.len(), 11);
assert_eq!(lattice.sentence(), "テストab");
assert_eq!(lattice.surface(0), "テストab");
@@ -482,7 +480,7 @@ mod tests {
#[test]
fn insert_test() {
let mut lattice = Lattice::from("ABあい", 0, 1, 2);
let mut lattice = Lattice::from("ABあい", 1, 2);
lattice.insert(0, 1, 0.0, 3);
lattice.insert(1, 1, 0.0, 4);
@@ -573,7 +571,7 @@ mod tests {
#[test]
fn test_viterbi() {
let mut lattice = Lattice::from("ABC", 0, 1, 2);
let mut lattice = Lattice::from("ABC", 1, 2);
assert_eq!(lattice.viterbi(), vec![]);
// Still incomplete
lattice.insert(0, 1, 0.0, 3);
@@ -586,7 +584,7 @@ mod tests {
#[test]
fn test_viterbi2() {
let mut lattice = Lattice::from("ABC", 0, 1, 2);
let mut lattice = Lattice::from("ABC", 1, 2);
lattice.insert(0, 1, 0.0, 3);
lattice.insert(1, 1, 0.0, 4);
@@ -606,7 +604,7 @@ mod tests {
#[test]
fn test_nbest() {
let mut lattice = Lattice::from("ABC", 0, 1, 2);
let mut lattice = Lattice::from("ABC", 1, 2);
lattice.insert(0, 1, 0.0, 3);
lattice.insert(1, 1, 0.0, 4);
lattice.insert(2, 1, 0.0, 5);
@@ -641,7 +639,7 @@ mod tests {
#[test]
fn test_populate() {
let mut lattice = Lattice::from("ABC", 0, 1, 2);
let mut lattice = Lattice::from("ABC", 1, 2);
lattice.insert(0, 1, 1.0, 3); // A
lattice.insert(1, 1, 1.2, 4); // B
lattice.insert(2, 1, 2.5, 5); // C

View File

@@ -18,7 +18,7 @@ pub struct Unigram {
cache: Cache<String, Vec<String>>,
trie: Trie<u8>,
pub min_score: f64,
pub(super) unk_id: usize,
pub(super) unk_id: Option<usize>,
pub(super) bos_id: usize,
pub(super) eos_id: usize,
@@ -54,7 +54,7 @@ impl Clone for Unigram {
impl std::fmt::Debug for Unigram {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("Unigram")
.field("vocab", &self.vocab)
.field("vocab", &self.vocab.len())
.field("unk_id", &self.unk_id)
.finish()
}
@@ -66,6 +66,7 @@ static K_UNK_PENALTY: f64 = 10.0;
pub enum UnigramError {
EmptyVocabulary,
UnkIdNotInVocabulary,
MissingUnkId,
}
impl std::fmt::Display for UnigramError {
@@ -77,6 +78,9 @@ impl std::fmt::Display for UnigramError {
UnigramError::UnkIdNotInVocabulary => {
write!(f, "The `unk_id` is larger than vocabulary size")
}
UnigramError::MissingUnkId => {
write!(f, "Encountered an unknown token but `unk_id` is missing")
}
}
}
}
@@ -86,7 +90,7 @@ impl std::error::Error for UnigramError {}
impl Default for Unigram {
fn default() -> Self {
let vocab = vec![("<unk>".to_string(), 0.0)];
Self::from(vocab, 0).unwrap()
Self::from(vocab, Some(0)).unwrap()
}
}
@@ -97,16 +101,18 @@ impl Unigram {
/// unk_id, is the index within the vocabulary.
/// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
/// Further versions might allow that part to be hidden.
pub fn from(vocab: Vec<(String, f64)>, unk_id: usize) -> Result<Self> {
pub fn from(vocab: Vec<(String, f64)>, unk_id: Option<usize>) -> Result<Self> {
let n = vocab.len();
let mut token_to_ids: TokenMap = HashMap::new();
let mut builder = TrieBuilder::default();
if vocab.is_empty() {
return Err(Box::new(UnigramError::EmptyVocabulary));
}
if unk_id >= vocab.len() {
return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
if let Some(unk_id) = unk_id {
if vocab.is_empty() {
return Err(Box::new(UnigramError::EmptyVocabulary));
}
if unk_id >= vocab.len() {
return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
}
}
let bos_id = n + 1;
@@ -187,7 +193,9 @@ impl Unigram {
}
if !has_single_node {
lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
if let Some(unk_id) = self.unk_id {
lattice.insert(begin_pos, mblen, unk_score, unk_id);
}
}
begin_pos += mblen
}
@@ -209,28 +217,28 @@ impl Unigram {
/// ("abc".to_string(), 5.0),
/// ("abcd".to_string(), 10.0),
/// ];
/// let model = Unigram::from(pieces, 0).unwrap();
/// let result = model.encode("abcdacdxx");
/// let model = Unigram::from(pieces, Some(0)).unwrap();
/// let result = model.encode("abcdacdxx").unwrap();
/// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
/// ```
pub fn encode(&self, sentence: &str) -> Vec<String> {
pub fn encode(&self, sentence: &str) -> Result<Vec<String>> {
if sentence.is_empty() {
return vec![];
return Ok(vec![]);
}
if let Some(result) = self.cache.get(sentence) {
result.to_vec()
Ok(result.to_vec())
} else {
let result = if self.is_optimized {
self.encode_optimized(sentence)
self.encode_optimized(sentence)?
} else {
self.encode_unoptimized(sentence)
self.encode_unoptimized(sentence)?
};
self.cache.set(sentence.to_owned(), result.clone());
result
Ok(result)
}
}
fn encode_optimized(&self, sentence: &str) -> Vec<String> {
fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
// https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
#[derive(Debug, Clone)]
struct BestPathNode {
@@ -290,7 +298,7 @@ impl Unigram {
{
target_node.best_path_score = candidate_best_path_score;
target_node.starts_at = Some(starts_at);
target_node.id = self.unk_id;
target_node.id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
}
}
starts_at += mblen
@@ -301,16 +309,9 @@ impl Unigram {
while ends_at > 0 {
let node = &best_path_ends_at[ends_at];
let starts_at = node.starts_at.unwrap();
if self.fuse_unk && node.id == self.unk_id {
if self.fuse_unk && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
token.push(
String::from_utf8(
sentence
.bytes()
.skip(starts_at)
.take(ends_at - starts_at)
.collect(),
)
.unwrap(),
String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
);
} else {
if !token.is_empty() {
@@ -319,14 +320,7 @@ impl Unigram {
token = vec![];
}
results.push(
String::from_utf8(
sentence
.bytes()
.skip(starts_at)
.take(ends_at - starts_at)
.collect(),
)
.unwrap(),
String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
);
}
ends_at = starts_at;
@@ -336,18 +330,18 @@ impl Unigram {
results.push(token.concat());
}
results.reverse();
results
Ok(results)
}
fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
fn encode_unoptimized(&self, sentence: &str) -> Result<Vec<String>> {
let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id);
self.populate_nodes(&mut lattice);
if self.fuse_unk {
let mut results = vec![];
let mut token = String::new();
for node in lattice.viterbi().iter() {
let item = lattice.piece(&node.borrow());
if node.borrow().id == self.unk_id {
if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
token.push_str(&item);
} else {
if !token.is_empty() {
@@ -360,9 +354,9 @@ impl Unigram {
if !token.is_empty() {
results.push(token);
}
results
Ok(results)
} else {
lattice.tokens()
Ok(lattice.tokens())
}
}
@@ -416,21 +410,20 @@ impl Model for Unigram {
}
fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> {
let tokens = self.encode(sentence);
let str_tokens = self.encode(sentence)?;
let mut offset = 0;
Ok(tokens
.iter()
.map(|string| {
let id: u32 = match self.token_to_ids.get(string) {
Some(id) => *id,
None => self.unk_id as u32,
};
let len = string.len();
let offsets = (offset, offset + len);
offset += len;
Token::new(id, string.to_string(), offsets)
})
.collect())
let mut tokens = Vec::with_capacity(str_tokens.len());
for string in str_tokens {
let id: u32 = match self.token_to_ids.get(&string) {
Some(id) => *id,
None => self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32,
};
let len = string.len();
let offsets = (offset, offset + len);
offset += len;
tokens.push(Token::new(id, string, offsets));
}
Ok(tokens)
}
fn token_to_id(&self, token: &str) -> Option<u32> {
@@ -465,9 +458,9 @@ mod tests {
#[test]
fn test_populate_nodes_unk() {
let pieces = vec![("<unk>".to_string(), 0.0)];
let model = Unigram::from(pieces, 0).unwrap();
let model = Unigram::from(pieces, Some(0)).unwrap();
let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
assert_eq!(lattice.begin_nodes[0].len(), 1);
@@ -490,9 +483,9 @@ mod tests {
("ab".to_string(), 0.3),
("bc".to_string(), 0.4),
];
let model = Unigram::from(pieces, 0).unwrap();
let model = Unigram::from(pieces, Some(0)).unwrap();
let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
assert_eq!(lattice.begin_nodes[0].len(), 2); // a, ab
@@ -527,8 +520,8 @@ mod tests {
("abcd".to_string(), 10.0),
];
let model = Unigram::from(sentencepieces, 0).unwrap();
let result = model.encode("abcd");
let model = Unigram::from(sentencepieces, Some(0)).unwrap();
let result = model.encode("abcd").unwrap();
assert_eq!(result, vec!["abcd"]);
}
@@ -549,35 +542,41 @@ mod tests {
("qr".to_string(), -0.5),
];
let mut model = Unigram::from(sentencepieces, 0).unwrap();
let mut model = Unigram::from(sentencepieces, Some(0)).unwrap();
for is_optimized in &[true, false] {
model.set_optimized(*is_optimized);
println!("IsOptimized {:?}", is_optimized);
assert_eq!(model.encode("abc"), vec!["abc"]);
assert_eq!(model.encode("AB"), vec!["AB"]);
assert_eq!(model.encode("abc").unwrap(), vec!["abc"]);
assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
model.set_fuse_unk(false);
assert_eq!(model.encode("AB"), vec!["A", "B"]);
assert_eq!(model.encode("AB").unwrap(), vec!["A", "B"]);
model.set_fuse_unk(true);
assert_eq!(model.encode("AB"), vec!["AB"]);
assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
assert_eq!(model.encode("abcd").unwrap(), vec!["ab", "cd"]);
assert_eq!(model.encode("abcc").unwrap(), vec!["abc", "c"]);
assert_eq!(
model.encode("xabcabaabcdd"),
model.encode("xabcabaabcdd").unwrap(),
vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
);
model.set_fuse_unk(false);
assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "", ""]);
assert_eq!(
model.encode("xyz東京").unwrap(),
vec!["x", "y", "z", "", ""]
);
model.set_fuse_unk(true);
assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]);
assert_eq!(model.encode("xyz東京").unwrap(), vec!["xyz東京"]);
// User encoded in original version
assert_eq!(model.encode("ABC"), vec!["ABC"]);
assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]);
assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
assert_eq!(model.encode("ABC").unwrap(), vec!["ABC"]);
assert_eq!(model.encode("abABCcd").unwrap(), vec!["ab", "ABC", "cd"]);
assert_eq!(
model.encode("ababcdabcdcd").unwrap(),
vec!["ab", "abcdabcd", "cd"]
);
assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
}
}
}

View File

@@ -52,11 +52,9 @@ impl<'de> Visitor<'de> for UnigramVisitor {
}
}
match (vocab, unk_id) {
(Some(vocab), Some(unk_id)) => Ok(Unigram::from(vocab, unk_id)
(Some(vocab), unk_id) => Ok(Unigram::from(vocab, unk_id)
.map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
(None, Some(_)) => Err(Error::custom("Missing vocab")),
(None, None) => Err(Error::custom("Missing vocab and unk_id")),
(Some(_), None) => Err(Error::custom("Missing unk_id")),
(None, _) => Err(Error::custom("Missing vocab")),
}
}
}
@@ -68,7 +66,7 @@ mod test {
#[test]
fn test_serialization() {
let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
let model = Unigram::from(vocab, 0).unwrap();
let model = Unigram::from(vocab, Some(0)).unwrap();
let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
@@ -79,7 +77,18 @@ mod test {
#[test]
fn test_serialization_unk_id_not_zero() {
let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
let model = Unigram::from(vocab, 1).unwrap();
let model = Unigram::from(vocab, Some(1)).unwrap();
let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(model, reconstructed);
}
#[test]
fn test_serialization_no_unk_id() {
let vocab = vec![("a".to_string(), -0.5)];
let model = Unigram::from(vocab, None).unwrap();
let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();

View File

@@ -51,8 +51,8 @@ pub struct UnigramTrainer {
#[builder(default = "HashSet::new()")]
initial_alphabet: HashSet<char>,
#[builder(default = "String::from(\"<unk>\")")]
unk_token: String,
#[builder(default = "None")]
unk_token: Option<String>,
#[builder(default = "16")]
max_piece_length: usize,
@@ -122,7 +122,33 @@ impl UnigramTrainer {
}
}
pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
Unigram::from(pieces, 0)
// Insert the necessary tokens
let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
if t.content == *unk {
Some(i)
} else {
None
}
});
match unk_id {
Some(id) => (Some(id), false),
None => (Some(0), true),
}
} else {
(None, false)
};
let mut special_tokens = self
.special_tokens
.iter()
.map(|t| (t.content.clone(), 0.0))
.collect::<Vec<_>>();
if need_add_unk {
special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
}
Unigram::from(special_tokens.into_iter().chain(pieces).collect(), unk_id)
}
fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -230,7 +256,7 @@ impl UnigramTrainer {
always_keep[id] = false;
continue;
}
let mut lattice = Lattice::from(token, 0, bos_id, eos_id);
let mut lattice = Lattice::from(token, bos_id, eos_id);
model.populate_nodes(&mut lattice);
let nbests = lattice.nbest(2);
@@ -255,7 +281,7 @@ impl UnigramTrainer {
let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
// TODO reparallelize this
for (i, (sentence, count)) in sentences.iter().enumerate() {
let mut lattice = Lattice::from(sentence, 0, bos_id, eos_id);
let mut lattice = Lattice::from(sentence, bos_id, eos_id);
model.populate_nodes(&mut lattice);
vsum += *count as f64;
for node_ref in lattice.viterbi() {
@@ -365,7 +391,7 @@ impl UnigramTrainer {
// TODO reparallelize this.
for (string, freq) in sentences {
let mut lattice = Lattice::from(string, model.unk_id, model.bos_id, model.eos_id);
let mut lattice = Lattice::from(string, model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice);
let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected);
ntokens += lattice.viterbi().len() as u32;
@@ -422,13 +448,7 @@ impl UnigramTrainer {
self.update_progress(&progress, sentences.len(), "Suffix array seeds");
let mut pieces: Vec<SentencePiece> =
Vec::with_capacity(self.vocab_size.try_into().unwrap());
// XXX: Make sure unk exists and are ids 0
pieces.push((self.unk_token.clone(), f64::NAN));
pieces.extend(
self.special_tokens
.iter()
.map(|tok| (tok.content.clone(), f64::NAN)),
);
pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?);
self.finalize_progress(&progress, sentences.len());
@@ -452,7 +472,7 @@ impl UnigramTrainer {
let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
self.update_progress(&progress, expected_updates, "EM training");
let required_chars = self.required_chars(&sentences);
let mut model = Unigram::from(pieces.clone(), 0)?;
let mut model = Unigram::from(pieces.clone(), None)?;
loop {
// Sub-EM iteration.
for _iter in 0..self.n_sub_iterations {
@@ -461,7 +481,7 @@ impl UnigramTrainer {
// Executes M step.
pieces = self.run_m_step(&pieces, &expected);
model = Unigram::from(pieces.clone(), 0)?;
model = Unigram::from(pieces.clone(), None)?;
// Useful comment for checking compatibility with spm
debug!(
@@ -485,7 +505,7 @@ impl UnigramTrainer {
// Prunes pieces.
pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
model = Unigram::from(pieces.clone(), 0)?;
model = Unigram::from(pieces.clone(), None)?;
}
self.finalize_progress(&progress, expected_updates);
@@ -598,6 +618,72 @@ mod tests {
);
}
#[test]
fn test_unk_token() {
// 1. Should add `unk_token` as first special token
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.special_tokens(vec![
AddedToken::from("[SEP]", true),
AddedToken::from("[CLS]", true),
])
.unk_token(Some("[UNK]".into()))
.build()
.unwrap();
let (unigram, _) = trainer
.train(HashMap::from_iter(vec![
("The".into(), 12),
("are".into(), 11),
]))
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
// 2. Let it where it is
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.special_tokens(vec![
AddedToken::from("[SEP]", true),
AddedToken::from("[CLS]", true),
AddedToken::from("[UNK]", true),
])
.unk_token(Some("[UNK]".into()))
.build()
.unwrap();
let (unigram, _) = trainer
.train(HashMap::from_iter(vec![
("The".into(), 12),
("are".into(), 11),
]))
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
// 3. Don't put it there if not needed
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.build()
.unwrap();
let (unigram, _) = trainer
.train(HashMap::from_iter(vec![
("The".into(), 12),
("are".into(), 11),
]))
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next().unwrap().0, "e".to_string());
}
#[test]
fn test_special_tokens() {
let trainer = UnigramTrainerBuilder::default()
@@ -617,7 +703,6 @@ mod tests {
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next(), Some(&("<unk>".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
}

View File

@@ -52,10 +52,11 @@ fn test_train_unigram_from_file() {
let trainer = UnigramTrainer::builder()
.show_progress(false)
.unk_token(Some("<unk>".into()))
.build()
.unwrap();
let (model, _) = trainer.train(word_counts).unwrap();
assert_eq!(model.get_vocab_size(), 719);
assert_eq!(model.get_vocab_size(), 717);
}
#[cfg(not(debug_assertions))]