Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-05 20:28:22 +00:00)
Address @n1t0 comments.
@@ -110,13 +110,8 @@ class Unigram(Model):
         vocab: ('`optional`) string:
             Path to a vocabulary JSON file.

-        is_spm_file: ('`optional`) bool:
-            If the file came out of sentencepiece, we need to load it differently
-
     """

     @staticmethod
-    def __init__(
-        self, vocab: Optional[str], is_spm_file: Optional[bool],
-    ):
+    def __init__(self, vocab: Optional[str]):
         pass
@@ -263,41 +263,15 @@ pub struct PyUnigram {}
 #[pymethods]
 impl PyUnigram {
     #[new]
-    #[args(kwargs = "**")]
-    fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let mut is_spm_file = false;
-        if let Some(kwargs) = kwargs {
-            for (key, val) in kwargs {
-                let key: &str = key.extract()?;
-                match key {
-                    "is_spm_file" => is_spm_file = val.extract()?,
-                    _ => println!("Ignored unknown kwargs option {}", key),
-                }
-            }
-        }
-
+    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
         if let Some(vocab) = vocab {
             let path = Path::new(vocab);
-            if is_spm_file {
-                match Unigram::load_spm(path) {
-                    Err(e) => {
-                        println!("Errors: {:?}", e);
-                        Err(exceptions::Exception::py_err(
-                            "Error while initializing Unigram from spm file",
-                        ))
-                    }
-                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
-                }
-            } else {
-                match Unigram::load(path) {
-                    Err(e) => {
-                        println!("Errors: {:?}", e);
-                        Err(exceptions::Exception::py_err(
-                            "Error while initializing Unigram",
-                        ))
-                    }
-                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
-                }
-            }
+            match Unigram::load(path) {
+                Err(e) => {
+                    println!("Errors: {:?}", e);
+                    Err(exceptions::Exception::py_err("Error while loading Unigram"))
+                }
+                Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
+            }
         } else {
             Ok((
@@ -190,25 +190,7 @@ impl PyUnigramTrainer {
                     "show_progress" => builder.show_progress(val.extract()?),
                     "n_sub_iterations" => builder.n_sub_iterations(val.extract()?),
                     "shrinking_factor" => builder.shrinking_factor(val.extract()?),
-                    "space_char" => {
-                        let string: String = val.extract()?;
-                        if string.chars().collect::<Vec<_>>().len() != 1 {
-                            return Err(exceptions::Exception::py_err(
-                                "space_char must be 1 unicode char long",
-                            ));
-                        }
-                        builder.space_char(string.chars().next().ok_or_else(|| {
-                            exceptions::Exception::py_err("space_char must not be 0 width")
-                        })?)
-                    }
                     "unk_token" => builder.unk_token(val.extract()?),
-                    "split_by_number" => builder.split_by_number(val.extract()?),
-                    "treat_whitespace_as_suffix" => {
-                        builder.treat_whitespace_as_suffix(val.extract()?)
-                    }
-                    "split_by_unicode_script" => builder.split_by_unicode_script(val.extract()?),
-                    "split_by_digits" => builder.split_by_digits(val.extract()?),
-                    "split_by_whitespace" => builder.split_by_whitespace(val.extract()?),
                     "max_piece_length" => builder.max_piece_length(val.extract()?),
                     "seed_size" => builder.seed_size(val.extract()?),
                     "special_tokens" => builder.special_tokens(
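Note: with the sentencepiece-specific switches (`space_char`, the `split_by_*` family, `treat_whitespace_as_suffix`) gone, the trainer surface reduces to the options kept above. A minimal sketch of driving the surviving setters directly from Rust; the `UnigramTrainerBuilder` name is an assumption (the crate pulls in derive_builder, per the Cargo.toml hunk below), and the values are illustrative:

    // Hypothetical sketch, not the crate's documented API: builder type name
    // assumed from derive_builder; setter names taken from the match arms above.
    let trainer = UnigramTrainerBuilder::default()
        .show_progress(true)
        .n_sub_iterations(2)
        .shrinking_factor(0.75)
        .unk_token("<unk>".to_string())
        .max_piece_length(16)
        .seed_size(1_000_000)
        .build()
        .expect("valid UnigramTrainer configuration");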
@@ -57,4 +57,3 @@ derive_builder = "0.9"
 criterion = "0.3"
 tempfile = "3.1"
 assert_approx_eq = "1.1"
-unicode-normalization = "0.1"
@@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D)

 SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt
 BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
-TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram.model $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt
+TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt

 .PHONY : build
 build :
@@ -5,18 +5,17 @@ use crate::tokenizer::{Model, Result, Token};
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::fs::File;
-use std::io::{BufRead, BufReader};
+use std::io::BufReader;
 use std::path::{Path, PathBuf};

 type TokenMap = HashMap<String, u32>;
-type Vocab = Vec<String>;
+type Vocab = Vec<(String, f64)>;

 /// A `Unigram` model to encode sentences.
 #[derive(Clone)]
 pub struct Unigram {
     token_to_ids: TokenMap,
     pub(crate) vocab: Vocab,
-    pub(super) scores: Vec<f64>,
     trie: Trie<char>,
     pub min_score: f64,
     pub(super) unk_id: usize,
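Note: the core refactor in this file is that `Vocab` now carries each token together with its score, so the separate `scores` vector disappears, and the zip-based bookkeeping it forced (in `PartialEq` just below, in serialization, and in the iterator) goes with it. A small self-contained sketch of the paired layout at a lookup site; the helper names here are hypothetical, not crate API:

    // Hypothetical helpers illustrating the (token, score) pair layout.
    type Vocab = Vec<(String, f64)>;

    fn token_of(vocab: &Vocab, id: usize) -> Option<&str> {
        vocab.get(id).map(|item| item.0.as_str())
    }

    fn score_of(vocab: &Vocab, id: usize) -> Option<f64> {
        vocab.get(id).map(|item| item.1)
    }

    fn main() {
        let vocab: Vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
        assert_eq!(token_of(&vocab, 1), Some("a"));
        assert_eq!(score_of(&vocab, 1), Some(-0.5));
    }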
@@ -27,10 +26,7 @@ pub struct Unigram {
 }
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
-        let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
-        let other_vocab: Vec<(&String, &f64)> =
-            other.vocab.iter().zip(other.scores.iter()).collect();
-        self.unk_id == other.unk_id && vocab == other_vocab
+        self.unk_id == other.unk_id && self.vocab == other.vocab
     }
 }

@@ -44,10 +40,31 @@ impl std::fmt::Debug for Unigram {

 static K_UNK_PENALTY: f64 = 10.0;

+#[derive(Debug)]
+pub enum UnigramError {
+    EmptyVocabulary,
+    UnkIdNotInVocabulary,
+}
+
+impl std::fmt::Display for UnigramError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            UnigramError::EmptyVocabulary => {
+                write!(f, "The vocabulary is empty but at least <unk> is needed")
+            }
+            UnigramError::UnkIdNotInVocabulary => {
+                write!(f, "The `unk_id` is larger than vocabulary size")
+            }
+        }
+    }
+}
+
+impl std::error::Error for UnigramError {}
+
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(&vocab, 0)
+        Self::from(&vocab, 0).unwrap()
     }
 }
@@ -58,36 +75,39 @@ impl Unigram {
     /// unk_id, is the index within the vocabulary.
     /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
     /// Further versions might allow that part to be hidden.
-    pub fn from(vocabulary: &[(String, f64)], unk_id: usize) -> Self {
+    pub fn from(
+        vocabulary: &[(String, f64)],
+        unk_id: usize,
+    ) -> std::result::Result<Self, UnigramError> {
         let n = vocabulary.len();
-        let mut vocab: Vec<String> = Vec::with_capacity(n);
-        let mut scores: Vec<f64> = Vec::with_capacity(n);
+        let vocab: Vec<(String, f64)> = vocabulary.iter().cloned().collect();
         let mut token_to_ids: TokenMap = HashMap::new();
         let mut builder = TrieBuilder::default();

-        assert!(n >= 1, "We need at least unk in the vocabulary");
-        assert!(unk_id < vocabulary.len(), "Unk id is invalid");
+        if vocabulary.is_empty() {
+            return Err(UnigramError::EmptyVocabulary);
+        }
+        if unk_id >= vocabulary.len() {
+            return Err(UnigramError::UnkIdNotInVocabulary);
+        }

         let bos_id = n + 1;
         let eos_id = n + 2;

+        let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocabulary.iter().enumerate() {
-            vocab.push(token.to_string());
-            scores.push(*score);
             token_to_ids.insert(token.to_string(), id as u32);
             let chars: Vec<char> = token.chars().collect();
             builder.push(&chars);
-        }
-        let min_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
-        if min_score == -f64::INFINITY {
-            panic!("Alert min_score !!");
+            if score < &min_score {
+                min_score = *score;
+            }
         }
         let trie = builder.build();
         let fuse_unk = true;

-        Unigram {
+        Ok(Unigram {
             vocab,
-            scores,
             token_to_ids,
             trie,
             min_score,
@@ -95,7 +115,7 @@ impl Unigram {
             eos_id,
             unk_id,
             fuse_unk,
-        }
+        })
     }

     #[cfg(test)]
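Note: together with the `UnigramError` enum introduced above, `Unigram::from` is now fallible instead of asserting, so invalid input surfaces as a value the caller can match on. A sketch of the two failure modes and the happy path, assuming the types from this file are in scope:

    // Sketch: branching on the new error variants; not a test from the crate.
    fn build_models() -> Result<(), UnigramError> {
        let empty: Vec<(String, f64)> = vec![];
        match Unigram::from(&empty, 0) {
            Err(UnigramError::EmptyVocabulary) => {} // at least <unk> is required
            Err(UnigramError::UnkIdNotInVocabulary) => {}
            Ok(_) => {}
        }

        // Happy path: propagate with `?`, as the trainer hunks below do.
        let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
        let _model = Unigram::from(&vocab, 0)?;
        Ok(())
    }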
@@ -124,8 +144,10 @@ impl Unigram {
                 let n = result.len();
                 let tok: String = result.into_iter().collect();
                 let id = *self.token_to_ids.get(&tok).unwrap();
-                assert_eq!(self.vocab[id as usize], tok);
-                let score: f64 = self.scores[id as usize];
+
+                let item = &self.vocab[id as usize];
+                assert_eq!(item.0, tok);
+                let score: f64 = item.1;
                 lattice.insert(begin_pos, n, score, id.try_into().unwrap());
                 if !has_single_node && n == 1 {
                     has_single_node = true;
@@ -154,7 +176,7 @@ impl Unigram {
     ///     ("abc".to_string(), 5.0),
     ///     ("abcd".to_string(), 10.0),
     /// ];
-    /// let model = Unigram::from(&pieces, 0);
+    /// let model = Unigram::from(&pieces, 0).unwrap();
     /// let result = model.encode("abcdacdxx");
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
@@ -187,43 +209,6 @@ impl Unigram {
         }
     }

-    /// Loads a SentencePiece output model.
-    /// In order to get the proper model with spm.
-    ///
-    /// ```ignore
-    /// spm_train --model=unigram --input=.... --model_prefix=myprefix ...
-    /// spm_export_vocab --model=myprefix.model --output=myprefix.txt
-    /// ```
-    ///
-    /// After that you can use the model with tokenizers library.
-    /// ```no_run
-    /// use tokenizers::models::unigram::Unigram;
-    /// use std::path::Path;
-    ///
-    /// let model = Unigram::load_spm(Path::new("myprefix.txt")).unwrap();
-    /// ```
-    pub fn load_spm<P: AsRef<Path>>(path: P) -> Result<Unigram> {
-        let file = BufReader::new(File::open(path)?);
-        let table = file
-            .lines()
-            .enumerate()
-            .map(|(i, line)| {
-                let line = line?;
-                let newline = line.replace('▁', " ");
-                let tokens: Vec<_> = newline.split('\t').collect();
-                match tokens.as_slice() {
-                    [token, score] => Ok((token.to_string(), score.parse()?)),
-                    _ => Err(format!("Line {} is invalid {:?}", i, line).into()),
-                }
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        // XXX: by default in spm unk is 0
-        // TODO: Check that we handle bos, eos correctly !
-        let u = Unigram::from(&table, 0);
-        Ok(u)
-    }
-
     /// Iterate of vocabulary of the model as a pair of `(token, score)`.
     pub fn iter(&self) -> UnigramIterator {
         UnigramIterator { model: self, i: 0 }
@@ -252,12 +237,12 @@ pub struct UnigramIterator<'a> {
 }

 impl<'a> Iterator for UnigramIterator<'a> {
-    type Item = (&'a String, f64);
+    type Item = &'a (String, f64);

     fn next(&mut self) -> Option<Self::Item> {
         let i = self.i;
         if i < self.model.len() {
-            let r = Some((&self.model.vocab[i], self.model.scores[i]));
+            let r = Some(&self.model.vocab[i]);
             self.i += 1;
             r
         } else {
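Note: the iterator now hands back a reference to the stored `(String, f64)` pair instead of assembling a fresh `(&String, f64)` tuple. Thanks to match ergonomics, callers can keep the same destructuring shape, but `score` binds as `&f64`, which is why the trainer hunks below switch to `.cloned()` and `*score`. A sketch, assuming `model: &Unigram` is in scope:

    // Sketch: each item is &(String, f64); `token` binds as &String, `score` as &f64.
    for (token, score) in model.iter() {
        println!("{}\t{}", token, *score);
    }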
@@ -296,7 +281,7 @@ impl Model for Unigram {

     fn id_to_token(&self, id: u32) -> Option<&str> {
         match self.vocab.get(id as usize) {
-            Some(string) => Some(string),
+            Some(item) => Some(&item.0),
             None => None,
         }
     }
@@ -322,7 +307,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(&pieces, 0);
+        let model = Unigram::from(&pieces, 0).unwrap();

         let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -347,7 +332,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(&pieces, 0);
+        let model = Unigram::from(&pieces, 0).unwrap();

         let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -384,7 +369,7 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];

-        let model = Unigram::from(&sentencepieces, 0);
+        let model = Unigram::from(&sentencepieces, 0).unwrap();
         let result = model.encode("abcd");
         assert_eq!(result, vec!["abcd"]);
     }
@@ -406,7 +391,7 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];

-        let mut model = Unigram::from(&sentencepieces, 0);
+        let mut model = Unigram::from(&sentencepieces, 0).unwrap();
         assert_eq!(model.encode("abc"), vec!["abc"]);
         assert_eq!(model.encode("AB"), vec!["AB"]);
@@ -1 +0,0 @@
@@ -13,9 +13,7 @@ impl Serialize for Unigram {
         let mut model = serializer.serialize_struct("Unigram", 2)?;

         model.serialize_field("unk_id", &self.unk_id)?;
-
-        let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
-        model.serialize_field("vocab", &vocab)?;
+        model.serialize_field("vocab", &self.vocab)?;

         model.end()
     }
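Note: serde encodes a `Vec<(String, f64)>` the same way it encoded the old zipped `Vec<(&String, &f64)>`, as an array of `[token, score]` arrays, so this simplification should leave the serialized format unchanged. A self-contained sketch with serde_json (the same crate the test module below uses):

    // Sketch: Rust tuples serialize as JSON arrays, so the "vocab" field keeps
    // the [["<unk>",0.0],["a",-0.5]] shape the zip-based code produced.
    fn main() {
        let vocab: Vec<(String, f64)> = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
        let json = serde_json::to_string(&vocab).unwrap();
        assert_eq!(json, r#"[["<unk>",0.0],["a",-0.5]]"#);
    }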
@@ -54,7 +52,8 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)),
+            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)
+                .map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
             (None, Some(_)) => Err(Error::custom("Missing vocab")),
             (None, None) => Err(Error::custom("Missing vocab and unk_id")),
             (Some(_), None) => Err(Error::custom("Missing unk_id")),
@@ -69,7 +68,7 @@ mod test {
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(&vocab, 0);
+        let model = Unigram::from(&vocab, 0).unwrap();

         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -167,7 +167,7 @@ impl UnigramTrainer {
         // true
     }

-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Unigram {
+    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
        // let mut pieces: Vec<SentencePiece> =
        //     Vec::with_capacity(self.vocab_size.try_into().unwrap());

@@ -175,7 +175,7 @@ impl UnigramTrainer {
         let min_score_penalty_delta = 0.0001;

         let mut pieces: HashMap<String, f64> = HashMap::new();
-        let existing_pieces: HashMap<&String, f64> = model.iter().collect();
+        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
         // XXX: Make sure bos, eos and unk exists and are ids 0, 1, 2
         pieces.insert(self.unk_token.clone(), 0.0);
         for c in required_chars {
@@ -191,7 +191,7 @@ impl UnigramTrainer {
         for (token, score) in model.iter() {
             match pieces.get(token) {
                 Some(_) => continue,
-                None => pieces.insert(token.to_string(), score),
+                None => pieces.insert(token.to_string(), *score),
             };
             if pieces.len() == self.vocab_size as usize {
                 break;
@@ -199,7 +199,7 @@ impl UnigramTrainer {
         }
         let mut final_pieces: Vec<SentencePiece> = pieces.into_iter().collect();
         final_pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
-        Unigram::from(&final_pieces, 0)
+        Ok(Unigram::from(&final_pieces, 0).unwrap())
     }

     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -546,7 +546,7 @@ impl UnigramTrainer {
         let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
         self.update_progress(&progress, expected_updates, "EM training");
         let required_chars = self.required_chars(&sentences);
-        let mut model = Unigram::from(&pieces, 0);
+        let mut model = Unigram::from(&pieces, 0)?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -555,7 +555,7 @@ impl UnigramTrainer {

                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                model = Unigram::from(&pieces, 0);
+                model = Unigram::from(&pieces, 0)?;
                 // Useful comment for checking compatibility with spm
                 println!(
                     "Em iter={} size={} obj={} num_tokens={} num_tokens/piece={}",
@@ -578,12 +578,12 @@ impl UnigramTrainer {

             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
-            model = Unigram::from(&pieces, 0);
+            model = Unigram::from(&pieces, 0)?;
         }
         self.finalize_progress(&progress, expected_updates);

         // Finally, adjusts the size of sentencepices to be |vocab_size|.
-        model = self.finalize(model, required_chars);
+        model = self.finalize(model, required_chars)?;

         Ok((model, self.special_tokens.clone()))
     }
@@ -41,32 +41,6 @@ fn test_unigram_from_file() {
             "。"
         ]
     );
-
-    // Check it works with spm_export_vocab model.
-    let model = Unigram::load_spm(Path::new("data/unigram.model")).unwrap();
-    assert_eq!(
-        model
-            .tokenize(string)
-            .unwrap()
-            .iter()
-            .map(|tok| tok.value.clone())
-            .collect::<Vec<_>>(),
-        vec![
-            "吾輩",
-            "《",
-            "わが",
-            "はい",
-            "》",
-            "は",
-            "猫",
-            "である",
-            "。",
-            "名前",
-            "はまだ",
-            "無い",
-            "。"
-        ]
-    );
 }

 #[cfg(not(debug_assertions))]