Address @n1t0 comments.

@@ -110,13 +110,8 @@ class Unigram(Model):
         vocab: ('`optional`) string:
             Path to a vocabulary JSON file.
 
-        is_spm_file: ('`optional`) bool:
-            If the file came out of sentencepiece, we need to load it differently
-
     """
 
     @staticmethod
-    def __init__(
-        self, vocab: Optional[str], is_spm_file: Optional[bool],
-    ):
+    def __init__(self, vocab: Optional[str]):
        pass
@@ -263,42 +263,16 @@ pub struct PyUnigram {}
 #[pymethods]
 impl PyUnigram {
     #[new]
-    #[args(kwargs = "**")]
-    fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let mut is_spm_file = false;
-        if let Some(kwargs) = kwargs {
-            for (key, val) in kwargs {
-                let key: &str = key.extract()?;
-                match key {
-                    "is_spm_file" => is_spm_file = val.extract()?,
-                    _ => println!("Ignored unknown kwargs option {}", key),
-                }
-            }
-        }
-
+    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
         if let Some(vocab) = vocab {
             let path = Path::new(vocab);
-            if is_spm_file {
-                match Unigram::load_spm(path) {
-                    Err(e) => {
-                        println!("Errors: {:?}", e);
-                        Err(exceptions::Exception::py_err(
-                            "Error while initializing Unigram from spm file",
-                        ))
-                    }
-                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
-                }
-            } else {
             match Unigram::load(path) {
                 Err(e) => {
                     println!("Errors: {:?}", e);
-                    Err(exceptions::Exception::py_err(
-                        "Error while initializing Unigram",
-                    ))
+                    Err(exceptions::Exception::py_err("Error while loading Unigram"))
                 }
                 Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
             }
-            }
         } else {
             Ok((
                 PyUnigram {},
@@ -190,25 +190,7 @@ impl PyUnigramTrainer
                     "show_progress" => builder.show_progress(val.extract()?),
                     "n_sub_iterations" => builder.n_sub_iterations(val.extract()?),
                     "shrinking_factor" => builder.shrinking_factor(val.extract()?),
-                    "space_char" => {
-                        let string: String = val.extract()?;
-                        if string.chars().collect::<Vec<_>>().len() != 1 {
-                            return Err(exceptions::Exception::py_err(
-                                "space_char must be 1 unicode char long",
-                            ));
-                        }
-                        builder.space_char(string.chars().next().ok_or_else(|| {
-                            exceptions::Exception::py_err("space_char must not be 0 width")
-                        })?)
-                    }
                     "unk_token" => builder.unk_token(val.extract()?),
-                    "split_by_number" => builder.split_by_number(val.extract()?),
-                    "treat_whitespace_as_suffix" => {
-                        builder.treat_whitespace_as_suffix(val.extract()?)
-                    }
-                    "split_by_unicode_script" => builder.split_by_unicode_script(val.extract()?),
-                    "split_by_digits" => builder.split_by_digits(val.extract()?),
-                    "split_by_whitespace" => builder.split_by_whitespace(val.extract()?),
                     "max_piece_length" => builder.max_piece_length(val.extract()?),
                     "seed_size" => builder.seed_size(val.extract()?),
                     "special_tokens" => builder.special_tokens(
@@ -57,4 +57,3 @@ derive_builder = "0.9"
 criterion = "0.3"
 tempfile = "3.1"
 assert_approx_eq = "1.1"
-unicode-normalization = "0.1"
@@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D)
 
 SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt
 BENCHMARK_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt
-TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram.model $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt
+TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt
 
 .PHONY : build
 build :
@@ -5,18 +5,17 @@ use crate::tokenizer::{Model, Result, Token};
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::fs::File;
-use std::io::{BufRead, BufReader};
+use std::io::BufReader;
 use std::path::{Path, PathBuf};
 
 type TokenMap = HashMap<String, u32>;
-type Vocab = Vec<String>;
+type Vocab = Vec<(String, f64)>;
 
 /// A `Unigram` model to encode sentences.
 #[derive(Clone)]
 pub struct Unigram {
     token_to_ids: TokenMap,
     pub(crate) vocab: Vocab,
-    pub(super) scores: Vec<f64>,
     trie: Trie<char>,
     pub min_score: f64,
     pub(super) unk_id: usize,
@@ -27,10 +26,7 @@ pub struct Unigram {
 }
 impl PartialEq for Unigram {
     fn eq(&self, other: &Self) -> bool {
-        let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
-        let other_vocab: Vec<(&String, &f64)> =
-            other.vocab.iter().zip(other.scores.iter()).collect();
-        self.unk_id == other.unk_id && vocab == other_vocab
+        self.unk_id == other.unk_id && self.vocab == other.vocab
     }
 }
 
@@ -44,10 +40,31 @@ impl std::fmt::Debug for Unigram {
 
 static K_UNK_PENALTY: f64 = 10.0;
 
+#[derive(Debug)]
+pub enum UnigramError {
+    EmptyVocabulary,
+    UnkIdNotInVocabulary,
+}
+
+impl std::fmt::Display for UnigramError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            UnigramError::EmptyVocabulary => {
+                write!(f, "The vocabulary is empty but at least <unk> is needed")
+            }
+            UnigramError::UnkIdNotInVocabulary => {
+                write!(f, "The `unk_id` is larger than vocabulary size")
+            }
+        }
+    }
+}
+
+impl std::error::Error for UnigramError {}
+
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(&vocab, 0)
+        Self::from(&vocab, 0).unwrap()
     }
 }
 
@@ -58,36 +75,39 @@ impl Unigram {
     /// unk_id, is the index within the vocabulary.
     /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
     /// Further versions might allow that part to be hidden.
-    pub fn from(vocabulary: &[(String, f64)], unk_id: usize) -> Self {
+    pub fn from(
+        vocabulary: &[(String, f64)],
+        unk_id: usize,
+    ) -> std::result::Result<Self, UnigramError> {
         let n = vocabulary.len();
-        let mut vocab: Vec<String> = Vec::with_capacity(n);
-        let mut scores: Vec<f64> = Vec::with_capacity(n);
+        let vocab: Vec<(String, f64)> = vocabulary.iter().cloned().collect();
         let mut token_to_ids: TokenMap = HashMap::new();
         let mut builder = TrieBuilder::default();
 
-        assert!(n >= 1, "We need at least unk in the vocabulary");
-        assert!(unk_id < vocabulary.len(), "Unk id is invalid");
+        if vocabulary.is_empty() {
+            return Err(UnigramError::EmptyVocabulary);
+        }
+        if unk_id >= vocabulary.len() {
+            return Err(UnigramError::UnkIdNotInVocabulary);
+        }
 
         let bos_id = n + 1;
         let eos_id = n + 2;
 
+        let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocabulary.iter().enumerate() {
-            vocab.push(token.to_string());
-            scores.push(*score);
             token_to_ids.insert(token.to_string(), id as u32);
             let chars: Vec<char> = token.chars().collect();
             builder.push(&chars);
+            if score < &min_score {
+                min_score = *score;
             }
-        let min_score = scores.iter().fold(f64::INFINITY, |a, &b| a.min(b));
-        if min_score == -f64::INFINITY {
-            panic!("Alert min_score !!");
         }
         let trie = builder.build();
         let fuse_unk = true;
 
-        Unigram {
+        Ok(Unigram {
             vocab,
-            scores,
             token_to_ids,
             trie,
             min_score,
@@ -95,7 +115,7 @@ impl Unigram {
             eos_id,
             unk_id,
             fuse_unk,
-        }
+        })
     }
 
     #[cfg(test)]
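
Usage sketch (illustrative only, not part of this commit; the vocabulary literal is invented) of the fallible constructor shown in the hunks above:

    use tokenizers::models::unigram::Unigram;

    fn main() {
        // An empty vocabulary or an out-of-range `unk_id` is reported as an
        // error value instead of tripping an assertion.
        let pieces = vec![("<unk>".to_string(), 0.0), ("hello".to_string(), -1.5)];
        match Unigram::from(&pieces, 0) {
            Ok(model) => println!("{:?}", model.encode("hello")),
            Err(e) => eprintln!("could not build Unigram: {}", e),
        }
    }
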
@@ -124,8 +144,10 @@ impl Unigram {
                 let n = result.len();
                 let tok: String = result.into_iter().collect();
                 let id = *self.token_to_ids.get(&tok).unwrap();
-                assert_eq!(self.vocab[id as usize], tok);
-                let score: f64 = self.scores[id as usize];
+                let item = &self.vocab[id as usize];
+                assert_eq!(item.0, tok);
+                let score: f64 = item.1;
                 lattice.insert(begin_pos, n, score, id.try_into().unwrap());
                 if !has_single_node && n == 1 {
                     has_single_node = true;
@@ -154,7 +176,7 @@ impl Unigram {
     /// ("abc".to_string(), 5.0),
     /// ("abcd".to_string(), 10.0),
     /// ];
-    /// let model = Unigram::from(&pieces, 0);
+    /// let model = Unigram::from(&pieces, 0).unwrap();
     /// let result = model.encode("abcdacdxx");
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
@@ -187,43 +209,6 @@ impl Unigram {
         }
     }
 
-    /// Loads a SentencePiece output model.
-    /// In order to get the proper model with spm.
-    ///
-    /// ```ignore
-    /// spm_train --model=unigram --input=.... --model_prefix=myprefix ...
-    /// spm_export_vocab --model=myprefix.model --output=myprefix.txt
-    /// ```
-    ///
-    /// After that you can use the model with tokenizers library.
-    /// ```no_run
-    /// use tokenizers::models::unigram::Unigram;
-    /// use std::path::Path;
-    ///
-    /// let model = Unigram::load_spm(Path::new("myprefix.txt")).unwrap();
-    /// ```
-    pub fn load_spm<P: AsRef<Path>>(path: P) -> Result<Unigram> {
-        let file = BufReader::new(File::open(path)?);
-        let table = file
-            .lines()
-            .enumerate()
-            .map(|(i, line)| {
-                let line = line?;
-                let newline = line.replace('▁', " ");
-                let tokens: Vec<_> = newline.split('\t').collect();
-                match tokens.as_slice() {
-                    [token, score] => Ok((token.to_string(), score.parse()?)),
-                    _ => Err(format!("Line {} is invalid {:?}", i, line).into()),
-                }
-            })
-            .collect::<Result<Vec<_>>>()?;
-
-        // XXX: by default in spm unk is 0
-        // TODO: Check that we handle bos, eos correctly !
-        let u = Unigram::from(&table, 0);
-        Ok(u)
-    }
-
     /// Iterate of vocabulary of the model as a pair of `(token, score)`.
     pub fn iter(&self) -> UnigramIterator {
         UnigramIterator { model: self, i: 0 }
@@ -252,12 +237,12 @@ pub struct UnigramIterator<'a> {
 }
 
 impl<'a> Iterator for UnigramIterator<'a> {
-    type Item = (&'a String, f64);
+    type Item = &'a (String, f64);
 
     fn next(&mut self) -> Option<Self::Item> {
         let i = self.i;
         if i < self.model.len() {
-            let r = Some((&self.model.vocab[i], self.model.scores[i]));
+            let r = Some(&self.model.vocab[i]);
             self.i += 1;
             r
         } else {
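
A small sketch (illustrative only, not part of this commit) of consuming the iterator's new item type `&(String, f64)`; the pair destructures directly through match ergonomics, as the trainer code further below also relies on:

    use tokenizers::models::unigram::Unigram;

    // Print every (token, score) pair held by the model; `token` and `score`
    // bind as references to the data stored in `vocab`.
    fn dump_vocab(model: &Unigram) {
        for (token, score) in model.iter() {
            println!("{}\t{}", token, score);
        }
    }
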
@@ -296,7 +281,7 @@ impl Model for Unigram {
 
     fn id_to_token(&self, id: u32) -> Option<&str> {
         match self.vocab.get(id as usize) {
-            Some(string) => Some(string),
+            Some(item) => Some(&item.0),
             None => None,
         }
     }
@@ -322,7 +307,7 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(&pieces, 0);
+        let model = Unigram::from(&pieces, 0).unwrap();
 
         let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -347,7 +332,7 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(&pieces, 0);
+        let model = Unigram::from(&pieces, 0).unwrap();
 
         let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
@@ -384,7 +369,7 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];
 
-        let model = Unigram::from(&sentencepieces, 0);
+        let model = Unigram::from(&sentencepieces, 0).unwrap();
         let result = model.encode("abcd");
         assert_eq!(result, vec!["abcd"]);
     }
@@ -406,7 +391,7 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];
 
-        let mut model = Unigram::from(&sentencepieces, 0);
+        let mut model = Unigram::from(&sentencepieces, 0).unwrap();
         assert_eq!(model.encode("abc"), vec!["abc"]);
         assert_eq!(model.encode("AB"), vec!["AB"]);
 
@@ -1 +0,0 @@
-
@@ -13,9 +13,7 @@ impl Serialize for Unigram {
         let mut model = serializer.serialize_struct("Unigram", 2)?;
 
         model.serialize_field("unk_id", &self.unk_id)?;
-        let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
-        model.serialize_field("vocab", &vocab)?;
+        model.serialize_field("vocab", &self.vocab)?;
 
         model.end()
     }
@@ -54,7 +52,8 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)),
+            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)
+                .map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
             (None, Some(_)) => Err(Error::custom("Missing vocab")),
             (None, None) => Err(Error::custom("Missing vocab and unk_id")),
             (Some(_), None) => Err(Error::custom("Missing unk_id")),
@@ -69,7 +68,7 @@ mod test {
     #[test]
     fn test_serialization() {
        let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(&vocab, 0);
+        let model = Unigram::from(&vocab, 0).unwrap();
 
        let data = serde_json::to_string(&model).unwrap();
        let reconstructed = serde_json::from_str(&data).unwrap();
@@ -167,7 +167,7 @@ impl UnigramTrainer {
         // true
     }
 
-    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Unigram {
+    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
         // let mut pieces: Vec<SentencePiece> =
         //     Vec::with_capacity(self.vocab_size.try_into().unwrap());
 
@@ -175,7 +175,7 @@ impl UnigramTrainer {
         let min_score_penalty_delta = 0.0001;
 
         let mut pieces: HashMap<String, f64> = HashMap::new();
-        let existing_pieces: HashMap<&String, f64> = model.iter().collect();
+        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
         // XXX: Make sure bos, eos and unk exists and are ids 0, 1, 2
         pieces.insert(self.unk_token.clone(), 0.0);
         for c in required_chars {
@@ -191,7 +191,7 @@ impl UnigramTrainer {
         for (token, score) in model.iter() {
             match pieces.get(token) {
                 Some(_) => continue,
-                None => pieces.insert(token.to_string(), score),
+                None => pieces.insert(token.to_string(), *score),
             };
             if pieces.len() == self.vocab_size as usize {
                 break;
@@ -199,7 +199,7 @@ impl UnigramTrainer {
         }
         let mut final_pieces: Vec<SentencePiece> = pieces.into_iter().collect();
         final_pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
-        Unigram::from(&final_pieces, 0)
+        Ok(Unigram::from(&final_pieces, 0).unwrap())
     }
 
     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -546,7 +546,7 @@ impl UnigramTrainer {
         let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
         self.update_progress(&progress, expected_updates, "EM training");
         let required_chars = self.required_chars(&sentences);
-        let mut model = Unigram::from(&pieces, 0);
+        let mut model = Unigram::from(&pieces, 0)?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -555,7 +555,7 @@ impl UnigramTrainer {
 
                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                model = Unigram::from(&pieces, 0);
+                model = Unigram::from(&pieces, 0)?;
                 // Useful comment for checking compatibility with spm
                 println!(
                     "Em iter={} size={} obj={} num_tokens={} num_tokens/piece={}",
@@ -578,12 +578,12 @@ impl UnigramTrainer {
 
             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
-            model = Unigram::from(&pieces, 0);
+            model = Unigram::from(&pieces, 0)?;
         }
         self.finalize_progress(&progress, expected_updates);
 
         // Finally, adjusts the size of sentencepices to be |vocab_size|.
-        model = self.finalize(model, required_chars);
+        model = self.finalize(model, required_chars)?;
 
         Ok((model, self.special_tokens.clone()))
     }
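
For context, a sketch of why the `?` calls above type-check: `UnigramError` implements `std::error::Error`, so `?` can box it into a `Send + Sync` error alias; the alias below is written to mirror the crate-wide `Result` but its exact definition here is an assumption, and this snippet is not part of the commit.

    use std::error::Error;
    use tokenizers::models::unigram::Unigram;

    // Assumed shape of the crate alias: any `Error + Send + Sync` type can be
    // boxed automatically by the `?` operator.
    type BoxedResult<T> = std::result::Result<T, Box<dyn Error + Send + Sync>>;

    fn build(pieces: &[(String, f64)]) -> BoxedResult<Unigram> {
        // `UnigramError` is converted into the boxed error by `?`, mirroring
        // how the trainer propagates constructor failures.
        let model = Unigram::from(pieces, 0)?;
        Ok(model)
    }
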
@@ -41,32 +41,6 @@ fn test_unigram_from_file() {
             "。"
         ]
     );
-
-    // Check it works with spm_export_vocab model.
-    let model = Unigram::load_spm(Path::new("data/unigram.model")).unwrap();
-    assert_eq!(
-        model
-            .tokenize(string)
-            .unwrap()
-            .iter()
-            .map(|tok| tok.value.clone())
-            .collect::<Vec<_>>(),
-        vec![
-            "吾輩",
-            "《",
-            "わが",
-            "はい",
-            "》",
-            "は",
-            "猫",
-            "である",
-            "。",
-            "名前",
-            "はまだ",
-            "無い",
-            "。"
-        ]
-    );
 }
 
 #[cfg(not(debug_assertions))]