From 73b5da917f152c77a9a6d7e99799e6624f6c88c9 Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Fri, 23 Oct 2020 12:18:57 -0400
Subject: [PATCH] Unigram - Add special_tokens at the end of training + optional unk

---
 bindings/python/src/models.rs                 |   3 +-
 .../python/tests/bindings/test_trainers.py    |  50 ++++++
 tokenizers/src/models/unigram/lattice.rs      |  22 ++-
 tokenizers/src/models/unigram/model.rs        | 153 +++++++++---------
 .../src/models/unigram/serialization.rs       |  21 ++-
 tokenizers/src/models/unigram/trainer.rs      | 119 ++++++++++++--
 tokenizers/tests/unigram.rs                   |   3 +-
 7 files changed, 256 insertions(+), 115 deletions(-)

diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index f522cc58..c06bef23 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -416,9 +416,8 @@ pub struct PyUnigram {}
 impl PyUnigram {
     #[new]
     fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
-        if vocab.is_some() && unk_id.is_none() || vocab.is_none() && unk_id.is_some() {}
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => {
+            (Some(vocab), unk_id) => {
                 let model = Unigram::from(vocab, unk_id).map_err(|e| {
                     exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
                 })?;
diff --git a/bindings/python/tests/bindings/test_trainers.py b/bindings/python/tests/bindings/test_trainers.py
index 8a128589..98dc6600 100644
--- a/bindings/python/tests/bindings/test_trainers.py
+++ b/bindings/python/tests/bindings/test_trainers.py
@@ -42,3 +42,53 @@ class TestUnigram:
         trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
         bpe_tokenizer.train(trainer, [train_files["small"]])
+
+    def test_train_with_special_tokens(self):
+        filename = "tests/data/dummy-unigram-special_tokens-train.txt"
+        with open(filename, "w") as f:
+            f.write(
+                """
+[CLS] The Zen of Python, by Tim Peters [SEP]
+[CLS] Beautiful is better than ugly. [SEP]
+[CLS] Explicit is better than implicit. [SEP]
+[CLS] Simple is better than complex. [SEP]
+[CLS] Complex is better than complicated. [SEP]
+[CLS] Flat is better than nested. [SEP]
+[CLS] Sparse is better than dense. [SEP]
+[CLS] Readability counts. [SEP]
+[CLS] Special cases aren't special enough to break the rules. [SEP]
+[CLS] Although practicality beats purity. [SEP]
+[CLS] Errors should never pass silently. [SEP]
+[CLS] Unless explicitly silenced. [SEP]
+[CLS] In the face of ambiguity, refuse the temptation to guess. [SEP]
+[CLS] There should be one-- and preferably only one --obvious way to do it. [SEP]
+[CLS] Although that way may not be obvious at first unless you're Dutch. [SEP]
+[CLS] Now is better than never. [SEP]
+[CLS] Although never is often better than *right* now. [SEP]
+[CLS] If the implementation is hard to explain, it's a bad idea. [SEP]
+[CLS] If the implementation is easy to explain, it may be a good idea. [SEP]
+[CLS] Namespaces are one honking great idea -- let's do more of those! [SEP]
+                """
+            )
+
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
+        )
+
+        tokenizer.train(trainer, [filename])
+
+        assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
+            "[CLS]",
+            " T",
+            "h",
+            "i",
+            "s",
+            " is ",
+            "a",
+            " ",
+            "t",
+            "es",
+            "t ",
+            "[SEP]",
+        ]
diff --git a/tokenizers/src/models/unigram/lattice.rs b/tokenizers/src/models/unigram/lattice.rs
index a8dbe208..98b03c3a 100644
--- a/tokenizers/src/models/unigram/lattice.rs
+++ b/tokenizers/src/models/unigram/lattice.rs
@@ -58,7 +58,6 @@ pub struct Lattice<'a> {
     pub(super) end_nodes: Vec<Vec<NodeRef>>,
     bos_id: usize,
     eos_id: usize,
-    unk_id: usize,
 }
 
 impl std::fmt::Display for Lattice<'_> {
@@ -136,7 +135,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
 }
 
 impl<'a> Lattice<'a> {
-    pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
+    pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Lattice<'a> {
         let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
@@ -161,7 +160,6 @@ impl<'a> Lattice<'a> {
             end_nodes,
             bos_id,
             eos_id,
-            unk_id,
         }
     }
 
@@ -439,16 +437,16 @@ mod tests {
     #[test]
     fn set_sentence() {
-        let lattice = Lattice::from("", 0, 1, 2);
+        let lattice = Lattice::from("", 1, 2);
 
         assert_eq!(lattice.len(), 0);
 
-        let lattice = Lattice::from("", 0, 1, 2);
+        let lattice = Lattice::from("", 1, 2);
 
         assert_eq!(lattice.len(), 0);
         assert_eq!(lattice.sentence(), "");
         assert_eq!(lattice.surface(0), "");
 
-        let lattice = Lattice::from("test", 0, 1, 2);
+        let lattice = Lattice::from("test", 1, 2);
         assert_eq!(lattice.len(), 4);
         assert_eq!(lattice.sentence(), "test");
         assert_eq!(lattice.surface(0), "test");
@@ -470,7 +468,7 @@ mod tests {
             eos.borrow().id
         );
 
-        let lattice = Lattice::from("テストab", 0, 1, 2);
+        let lattice = Lattice::from("テストab", 1, 2);
         assert_eq!(lattice.len(), 11);
         assert_eq!(lattice.sentence(), "テストab");
         assert_eq!(lattice.surface(0), "テストab");
@@ -482,7 +480,7 @@ mod tests {
     #[test]
     fn insert_test() {
-        let mut lattice = Lattice::from("ABあい", 0, 1, 2);
+        let mut lattice = Lattice::from("ABあい", 1, 2);
 
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
@@ -573,7 +571,7 @@ mod tests {
     #[test]
     fn test_viterbi() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
 
         assert_eq!(lattice.viterbi(), vec![]); // Still incomplete
         lattice.insert(0, 1, 0.0, 3);
@@ -586,7 +584,7 @@ mod tests {
     #[test]
     fn test_viterbi2() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
 
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
@@ -606,7 +604,7 @@ mod tests {
     #[test]
     fn test_nbest() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
         lattice.insert(2, 1, 0.0, 5);
@@ -641,7 +639,7 @@ mod tests {
     #[test]
     fn test_populate() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
         lattice.insert(0, 1, 1.0, 3); // A
         lattice.insert(1, 1, 1.2, 4); // B
         lattice.insert(2, 1, 2.5, 5); // C
diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs
index e57455c2..ea6e5042 100644
--- a/tokenizers/src/models/unigram/model.rs
+++ b/tokenizers/src/models/unigram/model.rs
@@ -18,7 +18,7 @@ pub struct Unigram {
     cache: Cache<String, Vec<String>>,
     trie: Trie<u8>,
     pub min_score: f64,
-    pub(super) unk_id: usize,
+    pub(super) unk_id: Option<usize>,
     pub(super) bos_id: usize,
     pub(super) eos_id: usize,
 
@@ -54,7 +54,7 @@ impl Clone for Unigram {
 impl std::fmt::Debug for Unigram {
     fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         fmt.debug_struct("Unigram")
-            .field("vocab", &self.vocab)
+            .field("vocab", &self.vocab.len())
             .field("unk_id", &self.unk_id)
             .finish()
     }
@@ -66,6 +66,7 @@ static K_UNK_PENALTY: f64 = 10.0;
 pub enum UnigramError {
     EmptyVocabulary,
     UnkIdNotInVocabulary,
+    MissingUnkId,
 }
 
 impl std::fmt::Display for UnigramError {
@@ -77,6 +78,9 @@ impl std::fmt::Display for UnigramError {
             UnigramError::UnkIdNotInVocabulary => {
                 write!(f, "The `unk_id` is larger than vocabulary size")
             }
+            UnigramError::MissingUnkId => {
+                write!(f, "Encountered an unknown token but `unk_id` is missing")
+            }
         }
     }
 }
 
 impl std::error::Error for UnigramError {}
 
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, 0).unwrap()
+        Self::from(vocab, Some(0)).unwrap()
     }
 }
 
@@ -97,16 +101,18 @@ impl Unigram {
     /// unk_id, is the index within the vocabulary.
     /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
     /// Further versions might allow that part to be hidden.
-    pub fn from(vocab: Vec<(String, f64)>, unk_id: usize) -> Result<Self> {
+    pub fn from(vocab: Vec<(String, f64)>, unk_id: Option<usize>) -> Result<Self> {
         let n = vocab.len();
         let mut token_to_ids: TokenMap = HashMap::new();
         let mut builder = TrieBuilder::default();
-        if vocab.is_empty() {
-            return Err(Box::new(UnigramError::EmptyVocabulary));
-        }
-        if unk_id >= vocab.len() {
-            return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
+        if let Some(unk_id) = unk_id {
+            if vocab.is_empty() {
+                return Err(Box::new(UnigramError::EmptyVocabulary));
+            }
+            if unk_id >= vocab.len() {
+                return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
+            }
         }
 
         let bos_id = n + 1;
@@ -187,7 +193,9 @@ impl Unigram {
             }
             if !has_single_node {
-                lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
+                if let Some(unk_id) = self.unk_id {
+                    lattice.insert(begin_pos, mblen, unk_score, unk_id);
+                }
             }
             begin_pos += mblen
         }
@@ -209,28 +217,28 @@ impl Unigram {
     ///     ("abc".to_string(), 5.0),
     ///     ("abcd".to_string(), 10.0),
     /// ];
-    /// let model = Unigram::from(pieces, 0).unwrap();
-    /// let result = model.encode("abcdacdxx");
+    /// let model = Unigram::from(pieces, Some(0)).unwrap();
+    /// let result = model.encode("abcdacdxx").unwrap();
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
-    pub fn encode(&self, sentence: &str) -> Vec<String> {
+    pub fn encode(&self, sentence: &str) -> Result<Vec<String>> {
         if sentence.is_empty() {
-            return vec![];
+            return Ok(vec![]);
         }
         if let Some(result) = self.cache.get(sentence) {
-            result.to_vec()
+            Ok(result.to_vec())
         } else {
             let result = if self.is_optimized {
-                self.encode_optimized(sentence)
+                self.encode_optimized(sentence)?
             } else {
-                self.encode_unoptimized(sentence)
+                self.encode_unoptimized(sentence)?
             };
             self.cache.set(sentence.to_owned(), result.clone());
-            result
+            Ok(result)
         }
     }
 
-    fn encode_optimized(&self, sentence: &str) -> Vec<String> {
+    fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
         // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
         #[derive(Debug, Clone)]
         struct BestPathNode {
@@ -290,7 +298,7 @@ impl Unigram {
                 {
                     target_node.best_path_score = candidate_best_path_score;
                     target_node.starts_at = Some(starts_at);
-                    target_node.id = self.unk_id;
+                    target_node.id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
                 }
             }
             starts_at += mblen
@@ -301,16 +309,9 @@ impl Unigram {
         while ends_at > 0 {
             let node = &best_path_ends_at[ends_at];
             let starts_at = node.starts_at.unwrap();
-            if self.fuse_unk && node.id == self.unk_id {
+            if self.fuse_unk && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
                 token.push(
-                    String::from_utf8(
-                        sentence
-                            .bytes()
-                            .skip(starts_at)
-                            .take(ends_at - starts_at)
-                            .collect(),
-                    )
-                    .unwrap(),
+                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
                 );
             } else {
                 if !token.is_empty() {
@@ -319,14 +320,7 @@ impl Unigram {
                     token = vec![];
                 }
                 results.push(
-                    String::from_utf8(
-                        sentence
-                            .bytes()
-                            .skip(starts_at)
-                            .take(ends_at - starts_at)
-                            .collect(),
-                    )
-                    .unwrap(),
+                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
                 );
             }
             ends_at = starts_at;
@@ -336,18 +330,18 @@ impl Unigram {
             results.push(token.concat());
         }
         results.reverse();
-        results
+        Ok(results)
     }
 
-    fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
-        let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
+    fn encode_unoptimized(&self, sentence: &str) -> Result<Vec<String>> {
+        let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
         if self.fuse_unk {
             let mut results = vec![];
             let mut token = String::new();
             for node in lattice.viterbi().iter() {
                 let item = lattice.piece(&node.borrow());
-                if node.borrow().id == self.unk_id {
+                if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
                     token.push_str(&item);
                 } else {
                     if !token.is_empty() {
@@ -360,9 +354,9 @@ impl Unigram {
             if !token.is_empty() {
                 results.push(token);
             }
-            results
+            Ok(results)
         } else {
-            lattice.tokens()
+            Ok(lattice.tokens())
         }
     }
 
@@ -416,21 +410,20 @@ impl Model for Unigram {
     }
 
     fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> {
-        let tokens = self.encode(sentence);
+        let str_tokens = self.encode(sentence)?;
         let mut offset = 0;
-        Ok(tokens
-            .iter()
-            .map(|string| {
-                let id: u32 = match self.token_to_ids.get(string) {
-                    Some(id) => *id,
-                    None => self.unk_id as u32,
-                };
-                let len = string.len();
-                let offsets = (offset, offset + len);
-                offset += len;
-                Token::new(id, string.to_string(), offsets)
-            })
-            .collect())
+        let mut tokens = Vec::with_capacity(str_tokens.len());
+        for string in str_tokens {
+            let id: u32 = match self.token_to_ids.get(&string) {
+                Some(id) => *id,
+                None => self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32,
+            };
+            let len = string.len();
+            let offsets = (offset, offset + len);
+            offset += len;
+            tokens.push(Token::new(id, string, offsets));
+        }
+        Ok(tokens)
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
@@ -465,9 +458,9 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, 0).unwrap();
+        let model = Unigram::from(pieces, Some(0)).unwrap();
 
-        let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
+        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
 
         assert_eq!(lattice.begin_nodes[0].len(), 1);
@@ -490,9 +483,9 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(pieces, 0).unwrap();
+        let model = Unigram::from(pieces, Some(0)).unwrap();
 
-        let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
+        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
 
         assert_eq!(lattice.begin_nodes[0].len(), 2); // a, ab
@@ -527,8 +520,8 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];
 
-        let model = Unigram::from(sentencepieces, 0).unwrap();
-        let result = model.encode("abcd");
+        let model = Unigram::from(sentencepieces, Some(0)).unwrap();
+        let result = model.encode("abcd").unwrap();
         assert_eq!(result, vec!["abcd"]);
     }
 
@@ -549,35 +542,41 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];
 
-        let mut model = Unigram::from(sentencepieces, 0).unwrap();
+        let mut model = Unigram::from(sentencepieces, Some(0)).unwrap();
 
         for is_optimized in &[true, false] {
             model.set_optimized(*is_optimized);
             println!("IsOptimized {:?}", is_optimized);
-            assert_eq!(model.encode("abc"), vec!["abc"]);
-            assert_eq!(model.encode("AB"), vec!["AB"]);
+            assert_eq!(model.encode("abc").unwrap(), vec!["abc"]);
+            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
 
             model.set_fuse_unk(false);
-            assert_eq!(model.encode("AB"), vec!["A", "B"]);
+            assert_eq!(model.encode("AB").unwrap(), vec!["A", "B"]);
             model.set_fuse_unk(true);
-            assert_eq!(model.encode("AB"), vec!["AB"]);
+            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
 
-            assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
-            assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
+            assert_eq!(model.encode("abcd").unwrap(), vec!["ab", "cd"]);
+            assert_eq!(model.encode("abcc").unwrap(), vec!["abc", "c"]);
             assert_eq!(
-                model.encode("xabcabaabcdd"),
+                model.encode("xabcabaabcdd").unwrap(),
                 vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
             );
             model.set_fuse_unk(false);
-            assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
+            assert_eq!(
+                model.encode("xyz東京").unwrap(),
+                vec!["x", "y", "z", "東", "京"]
+            );
             model.set_fuse_unk(true);
-            assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]);
+            assert_eq!(model.encode("xyz東京").unwrap(), vec!["xyz東京"]);
 
             // User encoded in original version
-            assert_eq!(model.encode("ABC"), vec!["ABC"]);
-            assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]);
-            assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
-            assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
+            assert_eq!(model.encode("ABC").unwrap(), vec!["ABC"]);
+            assert_eq!(model.encode("abABCcd").unwrap(), vec!["ab", "ABC", "cd"]);
+            assert_eq!(
+                model.encode("ababcdabcdcd").unwrap(),
+                vec!["ab", "abcdabcd", "cd"]
+            );
+            assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
         }
     }
 }
diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs
index cc9a51a2..95a8686b 100644
--- a/tokenizers/src/models/unigram/serialization.rs
+++ b/tokenizers/src/models/unigram/serialization.rs
@@ -52,11 +52,9 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(vocab, unk_id)
+            (Some(vocab), unk_id) => Ok(Unigram::from(vocab, unk_id)
                 .map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
-            (None, Some(_)) => Err(Error::custom("Missing vocab")),
-            (None, None) => Err(Error::custom("Missing vocab and unk_id")),
-            (Some(_), None) => Err(Error::custom("Missing unk_id")),
+            (None, _) => Err(Error::custom("Missing vocab")),
         }
     }
 }
@@ -68,7 +66,7 @@ mod test {
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, 0).unwrap();
+        let model = Unigram::from(vocab, Some(0)).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -79,7 +77,18 @@ mod test {
     #[test]
     fn test_serialization_unk_id_not_zero() {
         let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(vocab, 1).unwrap();
+        let model = Unigram::from(vocab, Some(1)).unwrap();
+
+        let data = serde_json::to_string(&model).unwrap();
+        let reconstructed = serde_json::from_str(&data).unwrap();
+
+        assert_eq!(model, reconstructed);
+    }
+
+    #[test]
+    fn test_serialization_no_unk_id() {
+        let vocab = vec![("a".to_string(), -0.5)];
+        let model = Unigram::from(vocab, None).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 5efb715a..96436a76 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -51,8 +51,8 @@ pub struct UnigramTrainer {
     #[builder(default = "HashSet::new()")]
     initial_alphabet: HashSet<char>,
 
-    #[builder(default = "String::from(\"<unk>\")")]
-    unk_token: String,
+    #[builder(default = "None")]
+    unk_token: Option<String>,
 
     #[builder(default = "16")]
     max_piece_length: usize,
@@ -122,7 +122,33 @@ impl UnigramTrainer {
             }
         }
         pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
-        Unigram::from(pieces, 0)
+
+        // Insert the necessary tokens
+        let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
+            let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
+                if t.content == *unk {
+                    Some(i)
+                } else {
+                    None
+                }
+            });
+            match unk_id {
+                Some(id) => (Some(id), false),
+                None => (Some(0), true),
+            }
+        } else {
+            (None, false)
+        };
+        let mut special_tokens = self
+            .special_tokens
+            .iter()
+            .map(|t| (t.content.clone(), 0.0))
+            .collect::<Vec<_>>();
+        if need_add_unk {
+            special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
+        }
+
+        Unigram::from(special_tokens.into_iter().chain(pieces).collect(), unk_id)
     }
 
     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -230,7 +256,7 @@ impl UnigramTrainer {
                 always_keep[id] = false;
                 continue;
             }
-            let mut lattice = Lattice::from(token, 0, bos_id, eos_id);
+            let mut lattice = Lattice::from(token, bos_id, eos_id);
             model.populate_nodes(&mut lattice);
 
             let nbests = lattice.nbest(2);
@@ -255,7 +281,7 @@ impl UnigramTrainer {
         let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
         // TODO reparallelize this
         for (i, (sentence, count)) in sentences.iter().enumerate() {
-            let mut lattice = Lattice::from(sentence, 0, bos_id, eos_id);
+            let mut lattice = Lattice::from(sentence, bos_id, eos_id);
             model.populate_nodes(&mut lattice);
             vsum += *count as f64;
             for node_ref in lattice.viterbi() {
@@ -365,7 +391,7 @@ impl UnigramTrainer {
 
         // TODO reparallelize this.
         for (string, freq) in sentences {
-            let mut lattice = Lattice::from(string, model.unk_id, model.bos_id, model.eos_id);
+            let mut lattice = Lattice::from(string, model.bos_id, model.eos_id);
             model.populate_nodes(&mut lattice);
 
             let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected);
             ntokens += lattice.viterbi().len() as u32;
@@ -422,13 +448,7 @@ impl UnigramTrainer {
         self.update_progress(&progress, sentences.len(), "Suffix array seeds");
         let mut pieces: Vec<SentencePiece> =
             Vec::with_capacity(self.vocab_size.try_into().unwrap());
-        // XXX: Make sure unk exists and are ids 0
-        pieces.push((self.unk_token.clone(), f64::NAN));
-        pieces.extend(
-            self.special_tokens
-                .iter()
-                .map(|tok| (tok.content.clone(), f64::NAN)),
-        );
+
         pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?);
         self.finalize_progress(&progress, sentences.len());
 
@@ -452,7 +472,7 @@ impl UnigramTrainer {
         let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
         self.update_progress(&progress, expected_updates, "EM training");
         let required_chars = self.required_chars(&sentences);
-        let mut model = Unigram::from(pieces.clone(), 0)?;
+        let mut model = Unigram::from(pieces.clone(), None)?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -461,7 +481,7 @@ impl UnigramTrainer {
 
                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                model = Unigram::from(pieces.clone(), 0)?;
+                model = Unigram::from(pieces.clone(), None)?;
 
                 // Useful comment for checking compatibility with spm
                 debug!(
@@ -485,7 +505,7 @@ impl UnigramTrainer {
 
             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
-            model = Unigram::from(pieces.clone(), 0)?;
+            model = Unigram::from(pieces.clone(), None)?;
         }
 
         self.finalize_progress(&progress, expected_updates);
@@ -598,6 +618,72 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unk_token() {
+        // 1. Should add `unk_token` as first special token
+        let trainer = UnigramTrainerBuilder::default()
+            .show_progress(false)
+            .special_tokens(vec![
+                AddedToken::from("[SEP]", true),
+                AddedToken::from("[CLS]", true),
+            ])
+            .unk_token(Some("[UNK]".into()))
+            .build()
+            .unwrap();
+
+        let (unigram, _) = trainer
+            .train(HashMap::from_iter(vec![
+                ("The".into(), 12),
+                ("are".into(), 11),
+            ]))
+            .unwrap();
+
+        let mut pieces = unigram.iter();
+        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
+
+        // 2. Leave it where it is
+        let trainer = UnigramTrainerBuilder::default()
+            .show_progress(false)
+            .special_tokens(vec![
+                AddedToken::from("[SEP]", true),
+                AddedToken::from("[CLS]", true),
+                AddedToken::from("[UNK]", true),
+            ])
+            .unk_token(Some("[UNK]".into()))
+            .build()
+            .unwrap();
+
+        let (unigram, _) = trainer
+            .train(HashMap::from_iter(vec![
+                ("The".into(), 12),
+                ("are".into(), 11),
+            ]))
+            .unwrap();
+
+        let mut pieces = unigram.iter();
+        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
+
+        // 3. Don't put it there if not needed
+        let trainer = UnigramTrainerBuilder::default()
+            .show_progress(false)
+            .build()
+            .unwrap();
+
+        let (unigram, _) = trainer
+            .train(HashMap::from_iter(vec![
+                ("The".into(), 12),
+                ("are".into(), 11),
+            ]))
+            .unwrap();
+
+        let mut pieces = unigram.iter();
+        assert_eq!(pieces.next().unwrap().0, "e".to_string());
+    }
+
     #[test]
     fn test_special_tokens() {
         let trainer = UnigramTrainerBuilder::default()
@@ -617,7 +703,6 @@ mod tests {
             .unwrap();
 
         let mut pieces = unigram.iter();
-        assert_eq!(pieces.next(), Some(&("<unk>".into(), 0.0)));
         assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
         assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
     }
diff --git a/tokenizers/tests/unigram.rs b/tokenizers/tests/unigram.rs
index 05c24a68..7bb5ef3f 100644
--- a/tokenizers/tests/unigram.rs
+++ b/tokenizers/tests/unigram.rs
@@ -52,10 +52,11 @@ fn test_train_unigram_from_file() {
 
     let trainer = UnigramTrainer::builder()
         .show_progress(false)
+        .unk_token(Some("<unk>".into()))
        .build()
        .unwrap();
 
     let (model, _) = trainer.train(word_counts).unwrap();
-    assert_eq!(model.get_vocab_size(), 719);
+    assert_eq!(model.get_vocab_size(), 717);
 }
 
 #[cfg(not(debug_assertions))]
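
Note: the following is a minimal, illustrative sketch (not part of the patch itself) of how the Rust API reads after this change. It assumes the public paths `tokenizers::models::unigram::{Unigram, UnigramTrainer}` and the crate-level `Result` alias; the vocabulary entries and scores are made up for the example.

    use tokenizers::models::unigram::{Unigram, UnigramTrainer};

    fn example() -> tokenizers::Result<()> {
        // With an unk token: `Some(0)` points at "<unk>" in the vocab below.
        let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
        let model = Unigram::from(vocab, Some(0))?;
        // `encode` now returns a `Result`, since hitting an unknown character
        // without an `unk_id` is an error (`UnigramError::MissingUnkId`).
        assert_eq!(model.encode("a")?, vec!["a"]);

        // Without an unk token: construction is allowed, but encoding text
        // that contains unknown characters will fail.
        let _no_unk = Unigram::from(vec![("a".to_string(), -0.5)], None)?;

        // Trainer side: special tokens are inserted at the front of the final
        // vocabulary, and an unk token is only added when `unk_token` is set.
        let _trainer = UnigramTrainer::builder()
            .show_progress(false)
            .unk_token(Some("<unk>".into()))
            .build()
            .unwrap();
        Ok(())
    }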