diff --git a/bindings/python/py_src/tokenizers/models/__init__.pyi b/bindings/python/py_src/tokenizers/models/__init__.pyi
index b9d76b0d..008bf375 100644
--- a/bindings/python/py_src/tokenizers/models/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/models/__init__.pyi
@@ -110,13 +110,8 @@ class Unigram(Model):
         vocab: ('`optional`) string:
             Path to a vocabulary JSON file.
 
-        is_spm_file: ('`optional`) bool:
-            If the file came out of sentencepiece, we need to load it differently
-
     """
 
     @staticmethod
-    def __init__(
-        self, vocab: Optional[str], is_spm_file: Optional[bool],
-    ):
+    def __init__(self, vocab: Optional[str]):
         pass
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index c41948bc..ddb384ea 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -263,41 +263,15 @@ pub struct PyUnigram {}
 #[pymethods]
 impl PyUnigram {
     #[new]
-    #[args(kwargs = "**")]
-    fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let mut is_spm_file = false;
-        if let Some(kwargs) = kwargs {
-            for (key, val) in kwargs {
-                let key: &str = key.extract()?;
-                match key {
-                    "is_spm_file" => is_spm_file = val.extract()?,
-                    _ => println!("Ignored unknown kwargs option {}", key),
-                }
-            }
-        }
-
+    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
         if let Some(vocab) = vocab {
             let path = Path::new(vocab);
-            if is_spm_file {
-                match Unigram::load_spm(path) {
-                    Err(e) => {
-                        println!("Errors: {:?}", e);
-                        Err(exceptions::Exception::py_err(
-                            "Error while initializing Unigram from spm file",
-                        ))
-                    }
-                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
-                }
-            } else {
-                match Unigram::load(path) {
-                    Err(e) => {
-                        println!("Errors: {:?}", e);
-                        Err(exceptions::Exception::py_err(
-                            "Error while initializing Unigram",
-                        ))
-                    }
-                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
+            match Unigram::load(path) {
+                Err(e) => {
+                    println!("Errors: {:?}", e);
+                    Err(exceptions::Exception::py_err("Error while loading Unigram"))
                 }
+                Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
             }
         } else {
             Ok((
diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs
index 7711720d..092ab1fe 100644
--- a/bindings/python/src/trainers.rs
+++ b/bindings/python/src/trainers.rs
@@ -190,25 +190,7 @@ impl PyUnigramTrainer {
                     "show_progress" => builder.show_progress(val.extract()?),
                     "n_sub_iterations" => builder.n_sub_iterations(val.extract()?),
                     "shrinking_factor" => builder.shrinking_factor(val.extract()?),
-                    "space_char" => {
-                        let string: String = val.extract()?;
-                        if string.chars().collect::<Vec<_>>().len() != 1 {
-                            return Err(exceptions::Exception::py_err(
-                                "space_char must be 1 unicode char long",
-                            ));
-                        }
-                        builder.space_char(string.chars().next().ok_or_else(|| {
-                            exceptions::Exception::py_err("space_char must not be 0 width")
-                        })?)
-                    }
                     "unk_token" => builder.unk_token(val.extract()?),
-                    "split_by_number" => builder.split_by_number(val.extract()?),
-                    "treat_whitespace_as_suffix" => {
-                        builder.treat_whitespace_as_suffix(val.extract()?)
- } - "split_by_unicode_script" => builder.split_by_unicode_script(val.extract()?), - "split_by_digits" => builder.split_by_digits(val.extract()?), - "split_by_whitespace" => builder.split_by_whitespace(val.extract()?), "max_piece_length" => builder.max_piece_length(val.extract()?), "seed_size" => builder.seed_size(val.extract()?), "special_tokens" => builder.special_tokens( diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index f2308dd9..ae36bcbe 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -57,4 +57,3 @@ derive_builder = "0.9" criterion = "0.3" tempfile = "3.1" assert_approx_eq = "1.1" -unicode-normalization = "0.1" diff --git a/tokenizers/Makefile b/tokenizers/Makefile index 3d9c0f0a..16da3b81 100644 --- a/tokenizers/Makefile +++ b/tokenizers/Makefile @@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D) SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt -TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram.model $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt +TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt .PHONY : build build : diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 470ae48e..88d664c9 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -5,18 +5,17 @@ use crate::tokenizer::{Model, Result, Token}; use std::collections::HashMap; use std::convert::TryInto; use std::fs::File; -use std::io::{BufRead, BufReader}; +use std::io::BufReader; use std::path::{Path, PathBuf}; type TokenMap = HashMap; -type Vocab = Vec; +type Vocab = Vec<(String, f64)>; /// A `Unigram` model to encode sentences. #[derive(Clone)] pub struct Unigram { token_to_ids: TokenMap, pub(crate) vocab: Vocab, - pub(super) scores: Vec, trie: Trie, pub min_score: f64, pub(super) unk_id: usize, @@ -27,10 +26,7 @@ pub struct Unigram { } impl PartialEq for Unigram { fn eq(&self, other: &Self) -> bool { - let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect(); - let other_vocab: Vec<(&String, &f64)> = - other.vocab.iter().zip(other.scores.iter()).collect(); - self.unk_id == other.unk_id && vocab == other_vocab + self.unk_id == other.unk_id && self.vocab == other.vocab } } @@ -44,10 +40,31 @@ impl std::fmt::Debug for Unigram { static K_UNK_PENALTY: f64 = 10.0; +#[derive(Debug)] +pub enum UnigramError { + EmptyVocabulary, + UnkIdNotInVocabulary, +} + +impl std::fmt::Display for UnigramError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + UnigramError::EmptyVocabulary => { + write!(f, "The vocabulary is empty but at least is needed") + } + UnigramError::UnkIdNotInVocabulary => { + write!(f, "The `unk_id` is larger than vocabulary size") + } + } + } +} + +impl std::error::Error for UnigramError {} + impl Default for Unigram { fn default() -> Self { let vocab = vec![("".to_string(), 0.0)]; - Self::from(&vocab, 0) + Self::from(&vocab, 0).unwrap() } } @@ -58,36 +75,39 @@ impl Unigram { /// unk_id, is the index within the vocabulary. /// For now `Unigram` *requires* at least `unk` because we might find a never seen char. /// Further versions might allow that part to be hidden. 
-    pub fn from(vocabulary: &[(String, f64)], unk_id: usize) -> Self {
+    pub fn from(
+        vocabulary: &[(String, f64)],
+        unk_id: usize,
+    ) -> std::result::Result<Self, UnigramError> {
         let n = vocabulary.len();
-        let mut vocab: Vec<String> = Vec::with_capacity(n);
-        let mut scores: Vec<f64> = Vec::with_capacity(n);
+        let vocab: Vec<(String, f64)> = vocabulary.iter().cloned().collect();
         let mut token_to_ids: TokenMap = HashMap::new();
         let mut builder = TrieBuilder::default();
 
-        assert!(n >= 1, "We need at least unk in the vocabulary");
-        assert!(unk_id < vocabulary.len(), "Unk id is invalid");
+        if vocabulary.is_empty() {
+            return Err(UnigramError::EmptyVocabulary);
+        }
+        if unk_id >= vocabulary.len() {
+            return Err(UnigramError::UnkIdNotInVocabulary);
+        }
         let bos_id = n + 1;
         let eos_id = n + 2;
 
+        let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocabulary.iter().enumerate() {
-            vocab.push(token.to_string());
-            scores.push(*score);
             token_to_ids.insert(token.to_string(), id as u32);
             let chars: Vec<char> = token.chars().collect();
             builder.push(&chars);
-        }
-        let min_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
-        if min_score == -f64::INFINITY {
-            panic!("Alert min_score !!");
+            if score < &min_score {
+                min_score = *score;
+            }
         }
 
         let trie = builder.build();
         let fuse_unk = true;
 
-        Unigram {
+        Ok(Unigram {
             vocab,
-            scores,
             token_to_ids,
             trie,
             min_score,
@@ -95,7 +115,7 @@ impl Unigram {
             eos_id,
             unk_id,
             fuse_unk,
-        }
+        })
     }
 
     #[cfg(test)]
@@ -124,8 +144,10 @@ impl Unigram {
                 let n = result.len();
                 let tok: String = result.into_iter().collect();
                 let id = *self.token_to_ids.get(&tok).unwrap();
-                assert_eq!(self.vocab[id as usize], tok);
-                let score: f64 = self.scores[id as usize];
+
+                let item = &self.vocab[id as usize];
+                assert_eq!(item.0, tok);
+                let score: f64 = item.1;
                 lattice.insert(begin_pos, n, score, id.try_into().unwrap());
                 if !has_single_node && n == 1 {
                     has_single_node = true;
@@ -154,7 +176,7 @@ impl Unigram {
     ///     ("abc".to_string(), 5.0),
    ///     ("abcd".to_string(), 10.0),
     /// ];
-    /// let model = Unigram::from(&pieces, 0);
+    /// let model = Unigram::from(&pieces, 0).unwrap();
     /// let result = model.encode("abcdacdxx");
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
@@ -187,43 +209,6 @@ impl Unigram {
         }
     }
 
-    /// Loads a SentencePiece output model.
-    /// In order to get the proper model with spm.
-    ///
-    /// ```ignore
-    /// spm_train --model=unigram --input=.... --model_prefix=myprefix ...
-    /// spm_export_vocab --model=myprefix.model --output=myprefix.txt
-    /// ```
-    ///
-    /// After that you can use the model with tokenizers library.
-    /// ```no_run
-    /// use tokenizers::models::unigram::Unigram;
-    /// use std::path::Path;
-    ///
-    /// let model = Unigram::load_spm(Path::new("myprefix.txt")).unwrap();
-    /// ```
-    pub fn load_spm<P: AsRef<Path>>(path: P) -> Result<Unigram> {
-        let file = BufReader::new(File::open(path)?);
-        let table = file
-            .lines()
-            .enumerate()
-            .map(|(i, line)| {
-                let line = line?;
-                let newline = line.replace('▁', " ");
-                let tokens: Vec<_> = newline.split('\t').collect();
-                match tokens.as_slice() {
-                    [token, score] => Ok((token.to_string(), score.parse()?)),
-                    _ => Err(format!("Line {} is invalid {:?}", i, line).into()),
-                }
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        // XXX: by default in spm unk is 0
-        // TODO: Check that we handle bos, eos correctly !
-        let u = Unigram::from(&table, 0);
-        Ok(u)
-    }
-
     /// Iterate of vocabulary of the model as a pair of `(token, score)`.
     pub fn iter(&self) -> UnigramIterator {
         UnigramIterator { model: self, i: 0 }
@@ -252,12 +237,12 @@ pub struct UnigramIterator<'a> {
 }
 
 impl<'a> Iterator for UnigramIterator<'a> {
-    type Item = (&'a String, f64);
+    type Item = &'a (String, f64);
 
     fn next(&mut self) -> Option<Self::Item> {
         let i = self.i;
         if i < self.model.len() {
-            let r = Some((&self.model.vocab[i], self.model.scores[i]));
+            let r = Some(&self.model.vocab[i]);
             self.i += 1;
             r
         } else {
@@ -296,7 +281,7 @@ impl Model for Unigram {
 
     fn id_to_token(&self, id: u32) -> Option<&str> {
         match self.vocab.get(id as usize) {
-            Some(string) => Some(string),
+            Some(item) => Some(&item.0),
             None => None,
         }
     }
@@ -322,7 +307,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(&pieces, 0);
+        let model = Unigram::from(&pieces, 0).unwrap();
         let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
 
         model.populate_nodes(&mut lattice);
@@ -347,7 +332,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(&pieces, 0);
+        let model = Unigram::from(&pieces, 0).unwrap();
         let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
 
         model.populate_nodes(&mut lattice);
@@ -384,7 +369,7 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];
 
-        let model = Unigram::from(&sentencepieces, 0);
+        let model = Unigram::from(&sentencepieces, 0).unwrap();
         let result = model.encode("abcd");
         assert_eq!(result, vec!["abcd"]);
     }
@@ -406,7 +391,7 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];
 
-        let mut model = Unigram::from(&sentencepieces, 0);
+        let mut model = Unigram::from(&sentencepieces, 0).unwrap();
         assert_eq!(model.encode("abc"), vec!["abc"]);
         assert_eq!(model.encode("AB"), vec!["AB"]);
diff --git a/tokenizers/src/models/unigram/normalize.rs b/tokenizers/src/models/unigram/normalize.rs
deleted file mode 100644
index 8b137891..00000000
--- a/tokenizers/src/models/unigram/normalize.rs
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tokenizers/src/models/unigram/serialization.rs b/tokenizers/src/models/unigram/serialization.rs
index a06984ca..86bfaa54 100644
--- a/tokenizers/src/models/unigram/serialization.rs
+++ b/tokenizers/src/models/unigram/serialization.rs
@@ -13,9 +13,7 @@ impl Serialize for Unigram {
         let mut model = serializer.serialize_struct("Unigram", 2)?;
 
         model.serialize_field("unk_id", &self.unk_id)?;
-
-        let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
-        model.serialize_field("vocab", &vocab)?;
+        model.serialize_field("vocab", &self.vocab)?;
 
         model.end()
     }
@@ -54,7 +52,8 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)),
+            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)
+                .map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
             (None, Some(_)) => Err(Error::custom("Missing vocab")),
             (None, None) => Err(Error::custom("Missing vocab and unk_id")),
             (Some(_), None) => Err(Error::custom("Missing unk_id")),
@@ -69,7 +68,7 @@ mod test {
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(&vocab, 0);
+        let model = Unigram::from(&vocab, 0).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
diff --git a/tokenizers/src/models/unigram/trainer.rs b/tokenizers/src/models/unigram/trainer.rs
index 3e36b679..f9c6fc1d 100644
--- a/tokenizers/src/models/unigram/trainer.rs
+++ b/tokenizers/src/models/unigram/trainer.rs
@@ -167,7 +167,7 @@ impl UnigramTrainer {
         // true
     }
 
-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Unigram {
+    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
         // let mut pieces: Vec<SentencePiece> =
         //     Vec::with_capacity(self.vocab_size.try_into().unwrap());
 
@@ -175,7 +175,7 @@ impl UnigramTrainer {
         let min_score_penalty_delta = 0.0001;
 
         let mut pieces: HashMap<String, f64> = HashMap::new();
-        let existing_pieces: HashMap<&String, f64> = model.iter().collect();
+        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
         // XXX: Make sure bos, eos and unk exists and are ids 0, 1, 2
         pieces.insert(self.unk_token.clone(), 0.0);
         for c in required_chars {
@@ -191,7 +191,7 @@ impl UnigramTrainer {
         for (token, score) in model.iter() {
             match pieces.get(token) {
                 Some(_) => continue,
-                None => pieces.insert(token.to_string(), score),
+                None => pieces.insert(token.to_string(), *score),
             };
             if pieces.len() == self.vocab_size as usize {
                 break;
@@ -199,7 +199,7 @@ impl UnigramTrainer {
         }
         let mut final_pieces: Vec<SentencePiece> = pieces.into_iter().collect();
         final_pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
-        Unigram::from(&final_pieces, 0)
+        Ok(Unigram::from(&final_pieces, 0).unwrap())
     }
 
     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -546,7 +546,7 @@ impl UnigramTrainer {
         let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
         self.update_progress(&progress, expected_updates, "EM training");
         let required_chars = self.required_chars(&sentences);
-        let mut model = Unigram::from(&pieces, 0);
+        let mut model = Unigram::from(&pieces, 0)?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -555,7 +555,7 @@ impl UnigramTrainer {
                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                model = Unigram::from(&pieces, 0);
+                model = Unigram::from(&pieces, 0)?;
 
                 // Useful comment for checking compatibility with spm
                 println!(
                     "Em iter={} size={} obj={} num_tokens={} num_tokens/piece={}",
@@ -578,12 +578,12 @@ impl UnigramTrainer {
 
             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
-            model = Unigram::from(&pieces, 0);
+            model = Unigram::from(&pieces, 0)?;
         }
         self.finalize_progress(&progress, expected_updates);
 
         // Finally, adjusts the size of sentencepices to be |vocab_size|.
-        model = self.finalize(model, required_chars);
+        model = self.finalize(model, required_chars)?;
 
         Ok((model, self.special_tokens.clone()))
     }
diff --git a/tokenizers/tests/unigram.rs b/tokenizers/tests/unigram.rs
index a75c16e9..d76cd46b 100644
--- a/tokenizers/tests/unigram.rs
+++ b/tokenizers/tests/unigram.rs
@@ -41,32 +41,6 @@ fn test_unigram_from_file() {
             "。"
         ]
     );
-
-    // Check it works with spm_export_vocab model.
-    let model = Unigram::load_spm(Path::new("data/unigram.model")).unwrap();
-    assert_eq!(
-        model
-            .tokenize(string)
-            .unwrap()
-            .iter()
-            .map(|tok| tok.value.clone())
-            .collect::<Vec<String>>(),
-        vec![
-            "吾輩",
-            "《",
-            "わが",
-            "はい",
-            "》",
-            "は",
-            "猫",
-            "である",
-            "。",
-            "名前",
-            "はまだ",
-            "無い",
-            "。"
-        ]
-    );
 }
 
 #[cfg(not(debug_assertions))]