Unigram - Add special_tokens at the end of training + optional unk

This commit is contained in:
Anthony MOI
2020-10-23 12:18:57 -04:00
committed by Anthony MOI
parent 390ef2f9f3
commit 73b5da917f
7 changed files with 256 additions and 115 deletions

View File

@@ -416,9 +416,8 @@ pub struct PyUnigram {}
impl PyUnigram { impl PyUnigram {
#[new] #[new]
fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> { fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
if vocab.is_some() && unk_id.is_none() || vocab.is_none() && unk_id.is_some() {}
match (vocab, unk_id) { match (vocab, unk_id) {
(Some(vocab), Some(unk_id)) => { (Some(vocab), unk_id) => {
let model = Unigram::from(vocab, unk_id).map_err(|e| { let model = Unigram::from(vocab, unk_id).map_err(|e| {
exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e)) exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
})?; })?;

View File

@@ -42,3 +42,53 @@ class TestUnigram:
trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False) trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
bpe_tokenizer.train(trainer, [train_files["small"]]) bpe_tokenizer.train(trainer, [train_files["small"]])
def test_train_with_special_tokens(self):
filename = "tests/data/dummy-unigram-special_tokens-train.txt"
with open(filename, "w") as f:
f.write(
"""
[CLS] The Zen of Python, by Tim Peters [SEP]
[CLS] Beautiful is better than ugly. [SEP]
[CLS] Explicit is better than implicit. [SEP]
[CLS] Simple is better than complex. [SEP]
[CLS] Complex is better than complicated. [SEP]
[CLS] Flat is better than nested. [SEP]
[CLS] Sparse is better than dense. [SEP]
[CLS] Readability counts. [SEP]
[CLS] Special cases aren't special enough to break the rules. [SEP]
[CLS] Although practicality beats purity. [SEP]
[CLS] Errors should never pass silently. [SEP]
[CLS] Unless explicitly silenced. [SEP]
[CLS] In the face of ambiguity, refuse the temptation to guess. [SEP]
[CLS] There should be one-- and preferably only one --obvious way to do it. [SEP]
[CLS] Although that way may not be obvious at first unless you're Dutch. [SEP]
[CLS] Now is better than never. [SEP]
[CLS] Although never is often better than *right* now. [SEP]
[CLS] If the implementation is hard to explain, it's a bad idea. [SEP]
[CLS] If the implementation is easy to explain, it may be a good idea. [SEP]
[CLS] Namespaces are one honking great idea -- let's do more of those! [SEP]
"""
)
tokenizer = Tokenizer(models.Unigram())
trainer = trainers.UnigramTrainer(
show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
)
tokenizer.train(trainer, [filename])
assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
"[CLS]",
" T",
"h",
"i",
"s",
" is ",
"a",
" ",
"t",
"es",
"t ",
"[SEP]",
]

View File

@@ -58,7 +58,6 @@ pub struct Lattice<'a> {
pub(super) end_nodes: Vec<Vec<NodeRef>>, pub(super) end_nodes: Vec<Vec<NodeRef>>,
bos_id: usize, bos_id: usize,
eos_id: usize, eos_id: usize,
unk_id: usize,
} }
impl std::fmt::Display for Lattice<'_> { impl std::fmt::Display for Lattice<'_> {
@@ -136,7 +135,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
} }
impl<'a> Lattice<'a> { impl<'a> Lattice<'a> {
pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> { pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Lattice<'a> {
let len = sentence.bytes().count(); let len = sentence.bytes().count();
let k_reserved_node_size = 16; let k_reserved_node_size = 16;
// We are adding 2 tokens, bos and eos // We are adding 2 tokens, bos and eos
@@ -161,7 +160,6 @@ impl<'a> Lattice<'a> {
end_nodes, end_nodes,
bos_id, bos_id,
eos_id, eos_id,
unk_id,
} }
} }
@@ -439,16 +437,16 @@ mod tests {
#[test] #[test]
fn set_sentence() { fn set_sentence() {
let lattice = Lattice::from("", 0, 1, 2); let lattice = Lattice::from("", 1, 2);
assert_eq!(lattice.len(), 0); assert_eq!(lattice.len(), 0);
let lattice = Lattice::from("", 0, 1, 2); let lattice = Lattice::from("", 1, 2);
assert_eq!(lattice.len(), 0); assert_eq!(lattice.len(), 0);
assert_eq!(lattice.sentence(), ""); assert_eq!(lattice.sentence(), "");
assert_eq!(lattice.surface(0), ""); assert_eq!(lattice.surface(0), "");
let lattice = Lattice::from("test", 0, 1, 2); let lattice = Lattice::from("test", 1, 2);
assert_eq!(lattice.len(), 4); assert_eq!(lattice.len(), 4);
assert_eq!(lattice.sentence(), "test"); assert_eq!(lattice.sentence(), "test");
assert_eq!(lattice.surface(0), "test"); assert_eq!(lattice.surface(0), "test");
@@ -470,7 +468,7 @@ mod tests {
eos.borrow().id eos.borrow().id
); );
let lattice = Lattice::from("テストab", 0, 1, 2); let lattice = Lattice::from("テストab", 1, 2);
assert_eq!(lattice.len(), 11); assert_eq!(lattice.len(), 11);
assert_eq!(lattice.sentence(), "テストab"); assert_eq!(lattice.sentence(), "テストab");
assert_eq!(lattice.surface(0), "テストab"); assert_eq!(lattice.surface(0), "テストab");
@@ -482,7 +480,7 @@ mod tests {
#[test] #[test]
fn insert_test() { fn insert_test() {
let mut lattice = Lattice::from("ABあい", 0, 1, 2); let mut lattice = Lattice::from("ABあい", 1, 2);
lattice.insert(0, 1, 0.0, 3); lattice.insert(0, 1, 0.0, 3);
lattice.insert(1, 1, 0.0, 4); lattice.insert(1, 1, 0.0, 4);
@@ -573,7 +571,7 @@ mod tests {
#[test] #[test]
fn test_viterbi() { fn test_viterbi() {
let mut lattice = Lattice::from("ABC", 0, 1, 2); let mut lattice = Lattice::from("ABC", 1, 2);
assert_eq!(lattice.viterbi(), vec![]); assert_eq!(lattice.viterbi(), vec![]);
// Still incomplete // Still incomplete
lattice.insert(0, 1, 0.0, 3); lattice.insert(0, 1, 0.0, 3);
@@ -586,7 +584,7 @@ mod tests {
#[test] #[test]
fn test_viterbi2() { fn test_viterbi2() {
let mut lattice = Lattice::from("ABC", 0, 1, 2); let mut lattice = Lattice::from("ABC", 1, 2);
lattice.insert(0, 1, 0.0, 3); lattice.insert(0, 1, 0.0, 3);
lattice.insert(1, 1, 0.0, 4); lattice.insert(1, 1, 0.0, 4);
@@ -606,7 +604,7 @@ mod tests {
#[test] #[test]
fn test_nbest() { fn test_nbest() {
let mut lattice = Lattice::from("ABC", 0, 1, 2); let mut lattice = Lattice::from("ABC", 1, 2);
lattice.insert(0, 1, 0.0, 3); lattice.insert(0, 1, 0.0, 3);
lattice.insert(1, 1, 0.0, 4); lattice.insert(1, 1, 0.0, 4);
lattice.insert(2, 1, 0.0, 5); lattice.insert(2, 1, 0.0, 5);
@@ -641,7 +639,7 @@ mod tests {
#[test] #[test]
fn test_populate() { fn test_populate() {
let mut lattice = Lattice::from("ABC", 0, 1, 2); let mut lattice = Lattice::from("ABC", 1, 2);
lattice.insert(0, 1, 1.0, 3); // A lattice.insert(0, 1, 1.0, 3); // A
lattice.insert(1, 1, 1.2, 4); // B lattice.insert(1, 1, 1.2, 4); // B
lattice.insert(2, 1, 2.5, 5); // C lattice.insert(2, 1, 2.5, 5); // C

View File

@@ -18,7 +18,7 @@ pub struct Unigram {
cache: Cache<String, Vec<String>>, cache: Cache<String, Vec<String>>,
trie: Trie<u8>, trie: Trie<u8>,
pub min_score: f64, pub min_score: f64,
pub(super) unk_id: usize, pub(super) unk_id: Option<usize>,
pub(super) bos_id: usize, pub(super) bos_id: usize,
pub(super) eos_id: usize, pub(super) eos_id: usize,
@@ -54,7 +54,7 @@ impl Clone for Unigram {
impl std::fmt::Debug for Unigram { impl std::fmt::Debug for Unigram {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("Unigram") fmt.debug_struct("Unigram")
.field("vocab", &self.vocab) .field("vocab", &self.vocab.len())
.field("unk_id", &self.unk_id) .field("unk_id", &self.unk_id)
.finish() .finish()
} }
@@ -66,6 +66,7 @@ static K_UNK_PENALTY: f64 = 10.0;
pub enum UnigramError { pub enum UnigramError {
EmptyVocabulary, EmptyVocabulary,
UnkIdNotInVocabulary, UnkIdNotInVocabulary,
MissingUnkId,
} }
impl std::fmt::Display for UnigramError { impl std::fmt::Display for UnigramError {
@@ -77,6 +78,9 @@ impl std::fmt::Display for UnigramError {
UnigramError::UnkIdNotInVocabulary => { UnigramError::UnkIdNotInVocabulary => {
write!(f, "The `unk_id` is larger than vocabulary size") write!(f, "The `unk_id` is larger than vocabulary size")
} }
UnigramError::MissingUnkId => {
write!(f, "Encountered an unknown token but `unk_id` is missing")
}
} }
} }
} }
@@ -86,7 +90,7 @@ impl std::error::Error for UnigramError {}
impl Default for Unigram { impl Default for Unigram {
fn default() -> Self { fn default() -> Self {
let vocab = vec![("<unk>".to_string(), 0.0)]; let vocab = vec![("<unk>".to_string(), 0.0)];
Self::from(vocab, 0).unwrap() Self::from(vocab, Some(0)).unwrap()
} }
} }
@@ -97,17 +101,19 @@ impl Unigram {
/// unk_id, is the index within the vocabulary. /// unk_id, is the index within the vocabulary.
/// For now `Unigram` *requires* at least `unk` because we might find a never seen char. /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
/// Further versions might allow that part to be hidden. /// Further versions might allow that part to be hidden.
pub fn from(vocab: Vec<(String, f64)>, unk_id: usize) -> Result<Self> { pub fn from(vocab: Vec<(String, f64)>, unk_id: Option<usize>) -> Result<Self> {
let n = vocab.len(); let n = vocab.len();
let mut token_to_ids: TokenMap = HashMap::new(); let mut token_to_ids: TokenMap = HashMap::new();
let mut builder = TrieBuilder::default(); let mut builder = TrieBuilder::default();
if let Some(unk_id) = unk_id {
if vocab.is_empty() { if vocab.is_empty() {
return Err(Box::new(UnigramError::EmptyVocabulary)); return Err(Box::new(UnigramError::EmptyVocabulary));
} }
if unk_id >= vocab.len() { if unk_id >= vocab.len() {
return Err(Box::new(UnigramError::UnkIdNotInVocabulary)); return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
} }
}
let bos_id = n + 1; let bos_id = n + 1;
let eos_id = n + 2; let eos_id = n + 2;
@@ -187,7 +193,9 @@ impl Unigram {
} }
if !has_single_node { if !has_single_node {
lattice.insert(begin_pos, mblen, unk_score, self.unk_id); if let Some(unk_id) = self.unk_id {
lattice.insert(begin_pos, mblen, unk_score, unk_id);
}
} }
begin_pos += mblen begin_pos += mblen
} }
@@ -209,28 +217,28 @@ impl Unigram {
/// ("abc".to_string(), 5.0), /// ("abc".to_string(), 5.0),
/// ("abcd".to_string(), 10.0), /// ("abcd".to_string(), 10.0),
/// ]; /// ];
/// let model = Unigram::from(pieces, 0).unwrap(); /// let model = Unigram::from(pieces, Some(0)).unwrap();
/// let result = model.encode("abcdacdxx"); /// let result = model.encode("abcdacdxx").unwrap();
/// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]); /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
/// ``` /// ```
pub fn encode(&self, sentence: &str) -> Vec<String> { pub fn encode(&self, sentence: &str) -> Result<Vec<String>> {
if sentence.is_empty() { if sentence.is_empty() {
return vec![]; return Ok(vec![]);
} }
if let Some(result) = self.cache.get(sentence) { if let Some(result) = self.cache.get(sentence) {
result.to_vec() Ok(result.to_vec())
} else { } else {
let result = if self.is_optimized { let result = if self.is_optimized {
self.encode_optimized(sentence) self.encode_optimized(sentence)?
} else { } else {
self.encode_unoptimized(sentence) self.encode_unoptimized(sentence)?
}; };
self.cache.set(sentence.to_owned(), result.clone()); self.cache.set(sentence.to_owned(), result.clone());
result Ok(result)
} }
} }
fn encode_optimized(&self, sentence: &str) -> Vec<String> { fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
// https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600 // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct BestPathNode { struct BestPathNode {
@@ -290,7 +298,7 @@ impl Unigram {
{ {
target_node.best_path_score = candidate_best_path_score; target_node.best_path_score = candidate_best_path_score;
target_node.starts_at = Some(starts_at); target_node.starts_at = Some(starts_at);
target_node.id = self.unk_id; target_node.id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
} }
} }
starts_at += mblen starts_at += mblen
@@ -301,16 +309,9 @@ impl Unigram {
while ends_at > 0 { while ends_at > 0 {
let node = &best_path_ends_at[ends_at]; let node = &best_path_ends_at[ends_at];
let starts_at = node.starts_at.unwrap(); let starts_at = node.starts_at.unwrap();
if self.fuse_unk && node.id == self.unk_id { if self.fuse_unk && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
token.push( token.push(
String::from_utf8( String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
sentence
.bytes()
.skip(starts_at)
.take(ends_at - starts_at)
.collect(),
)
.unwrap(),
); );
} else { } else {
if !token.is_empty() { if !token.is_empty() {
@@ -319,14 +320,7 @@ impl Unigram {
token = vec![]; token = vec![];
} }
results.push( results.push(
String::from_utf8( String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
sentence
.bytes()
.skip(starts_at)
.take(ends_at - starts_at)
.collect(),
)
.unwrap(),
); );
} }
ends_at = starts_at; ends_at = starts_at;
@@ -336,18 +330,18 @@ impl Unigram {
results.push(token.concat()); results.push(token.concat());
} }
results.reverse(); results.reverse();
results Ok(results)
} }
fn encode_unoptimized(&self, sentence: &str) -> Vec<String> { fn encode_unoptimized(&self, sentence: &str) -> Result<Vec<String>> {
let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id); let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id);
self.populate_nodes(&mut lattice); self.populate_nodes(&mut lattice);
if self.fuse_unk { if self.fuse_unk {
let mut results = vec![]; let mut results = vec![];
let mut token = String::new(); let mut token = String::new();
for node in lattice.viterbi().iter() { for node in lattice.viterbi().iter() {
let item = lattice.piece(&node.borrow()); let item = lattice.piece(&node.borrow());
if node.borrow().id == self.unk_id { if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
token.push_str(&item); token.push_str(&item);
} else { } else {
if !token.is_empty() { if !token.is_empty() {
@@ -360,9 +354,9 @@ impl Unigram {
if !token.is_empty() { if !token.is_empty() {
results.push(token); results.push(token);
} }
results Ok(results)
} else { } else {
lattice.tokens() Ok(lattice.tokens())
} }
} }
@@ -416,21 +410,20 @@ impl Model for Unigram {
} }
fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> { fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> {
let tokens = self.encode(sentence); let str_tokens = self.encode(sentence)?;
let mut offset = 0; let mut offset = 0;
Ok(tokens let mut tokens = Vec::with_capacity(str_tokens.len());
.iter() for string in str_tokens {
.map(|string| { let id: u32 = match self.token_to_ids.get(&string) {
let id: u32 = match self.token_to_ids.get(string) {
Some(id) => *id, Some(id) => *id,
None => self.unk_id as u32, None => self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32,
}; };
let len = string.len(); let len = string.len();
let offsets = (offset, offset + len); let offsets = (offset, offset + len);
offset += len; offset += len;
Token::new(id, string.to_string(), offsets) tokens.push(Token::new(id, string, offsets));
}) }
.collect()) Ok(tokens)
} }
fn token_to_id(&self, token: &str) -> Option<u32> { fn token_to_id(&self, token: &str) -> Option<u32> {
@@ -465,9 +458,9 @@ mod tests {
#[test] #[test]
fn test_populate_nodes_unk() { fn test_populate_nodes_unk() {
let pieces = vec![("<unk>".to_string(), 0.0)]; let pieces = vec![("<unk>".to_string(), 0.0)];
let model = Unigram::from(pieces, 0).unwrap(); let model = Unigram::from(pieces, Some(0)).unwrap();
let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id); let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice); model.populate_nodes(&mut lattice);
assert_eq!(lattice.begin_nodes[0].len(), 1); assert_eq!(lattice.begin_nodes[0].len(), 1);
@@ -490,9 +483,9 @@ mod tests {
("ab".to_string(), 0.3), ("ab".to_string(), 0.3),
("bc".to_string(), 0.4), ("bc".to_string(), 0.4),
]; ];
let model = Unigram::from(pieces, 0).unwrap(); let model = Unigram::from(pieces, Some(0)).unwrap();
let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id); let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice); model.populate_nodes(&mut lattice);
assert_eq!(lattice.begin_nodes[0].len(), 2); // a, ab assert_eq!(lattice.begin_nodes[0].len(), 2); // a, ab
@@ -527,8 +520,8 @@ mod tests {
("abcd".to_string(), 10.0), ("abcd".to_string(), 10.0),
]; ];
let model = Unigram::from(sentencepieces, 0).unwrap(); let model = Unigram::from(sentencepieces, Some(0)).unwrap();
let result = model.encode("abcd"); let result = model.encode("abcd").unwrap();
assert_eq!(result, vec!["abcd"]); assert_eq!(result, vec!["abcd"]);
} }
@@ -549,35 +542,41 @@ mod tests {
("qr".to_string(), -0.5), ("qr".to_string(), -0.5),
]; ];
let mut model = Unigram::from(sentencepieces, 0).unwrap(); let mut model = Unigram::from(sentencepieces, Some(0)).unwrap();
for is_optimized in &[true, false] { for is_optimized in &[true, false] {
model.set_optimized(*is_optimized); model.set_optimized(*is_optimized);
println!("IsOptimized {:?}", is_optimized); println!("IsOptimized {:?}", is_optimized);
assert_eq!(model.encode("abc"), vec!["abc"]); assert_eq!(model.encode("abc").unwrap(), vec!["abc"]);
assert_eq!(model.encode("AB"), vec!["AB"]); assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
model.set_fuse_unk(false); model.set_fuse_unk(false);
assert_eq!(model.encode("AB"), vec!["A", "B"]); assert_eq!(model.encode("AB").unwrap(), vec!["A", "B"]);
model.set_fuse_unk(true); model.set_fuse_unk(true);
assert_eq!(model.encode("AB"), vec!["AB"]); assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
assert_eq!(model.encode("abcd"), vec!["ab", "cd"]); assert_eq!(model.encode("abcd").unwrap(), vec!["ab", "cd"]);
assert_eq!(model.encode("abcc"), vec!["abc", "c"]); assert_eq!(model.encode("abcc").unwrap(), vec!["abc", "c"]);
assert_eq!( assert_eq!(
model.encode("xabcabaabcdd"), model.encode("xabcabaabcdd").unwrap(),
vec!["x", "abc", "ab", "a", "ab", "cd", "d"] vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
); );
model.set_fuse_unk(false); model.set_fuse_unk(false);
assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "", ""]); assert_eq!(
model.encode("xyz東京").unwrap(),
vec!["x", "y", "z", "", ""]
);
model.set_fuse_unk(true); model.set_fuse_unk(true);
assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]); assert_eq!(model.encode("xyz東京").unwrap(), vec!["xyz東京"]);
// User encoded in original version // User encoded in original version
assert_eq!(model.encode("ABC"), vec!["ABC"]); assert_eq!(model.encode("ABC").unwrap(), vec!["ABC"]);
assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]); assert_eq!(model.encode("abABCcd").unwrap(), vec!["ab", "ABC", "cd"]);
assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]); assert_eq!(
assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]); model.encode("ababcdabcdcd").unwrap(),
vec!["ab", "abcdabcd", "cd"]
);
assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
} }
} }
} }

View File

@@ -52,11 +52,9 @@ impl<'de> Visitor<'de> for UnigramVisitor {
} }
} }
match (vocab, unk_id) { match (vocab, unk_id) {
(Some(vocab), Some(unk_id)) => Ok(Unigram::from(vocab, unk_id) (Some(vocab), unk_id) => Ok(Unigram::from(vocab, unk_id)
.map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?), .map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
(None, Some(_)) => Err(Error::custom("Missing vocab")), (None, _) => Err(Error::custom("Missing vocab")),
(None, None) => Err(Error::custom("Missing vocab and unk_id")),
(Some(_), None) => Err(Error::custom("Missing unk_id")),
} }
} }
} }
@@ -68,7 +66,7 @@ mod test {
#[test] #[test]
fn test_serialization() { fn test_serialization() {
let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)]; let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
let model = Unigram::from(vocab, 0).unwrap(); let model = Unigram::from(vocab, Some(0)).unwrap();
let data = serde_json::to_string(&model).unwrap(); let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap(); let reconstructed = serde_json::from_str(&data).unwrap();
@@ -79,7 +77,18 @@ mod test {
#[test] #[test]
fn test_serialization_unk_id_not_zero() { fn test_serialization_unk_id_not_zero() {
let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)]; let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
let model = Unigram::from(vocab, 1).unwrap(); let model = Unigram::from(vocab, Some(1)).unwrap();
let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(model, reconstructed);
}
#[test]
fn test_serialization_no_unk_id() {
let vocab = vec![("a".to_string(), -0.5)];
let model = Unigram::from(vocab, None).unwrap();
let data = serde_json::to_string(&model).unwrap(); let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap(); let reconstructed = serde_json::from_str(&data).unwrap();

View File

@@ -51,8 +51,8 @@ pub struct UnigramTrainer {
#[builder(default = "HashSet::new()")] #[builder(default = "HashSet::new()")]
initial_alphabet: HashSet<char>, initial_alphabet: HashSet<char>,
#[builder(default = "String::from(\"<unk>\")")] #[builder(default = "None")]
unk_token: String, unk_token: Option<String>,
#[builder(default = "16")] #[builder(default = "16")]
max_piece_length: usize, max_piece_length: usize,
@@ -122,7 +122,33 @@ impl UnigramTrainer {
} }
} }
pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap()); pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
Unigram::from(pieces, 0)
// Insert the necessary tokens
let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
if t.content == *unk {
Some(i)
} else {
None
}
});
match unk_id {
Some(id) => (Some(id), false),
None => (Some(0), true),
}
} else {
(None, false)
};
let mut special_tokens = self
.special_tokens
.iter()
.map(|t| (t.content.clone(), 0.0))
.collect::<Vec<_>>();
if need_add_unk {
special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
}
Unigram::from(special_tokens.into_iter().chain(pieces).collect(), unk_id)
} }
fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> { fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -230,7 +256,7 @@ impl UnigramTrainer {
always_keep[id] = false; always_keep[id] = false;
continue; continue;
} }
let mut lattice = Lattice::from(token, 0, bos_id, eos_id); let mut lattice = Lattice::from(token, bos_id, eos_id);
model.populate_nodes(&mut lattice); model.populate_nodes(&mut lattice);
let nbests = lattice.nbest(2); let nbests = lattice.nbest(2);
@@ -255,7 +281,7 @@ impl UnigramTrainer {
let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()]; let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
// TODO reparallelize this // TODO reparallelize this
for (i, (sentence, count)) in sentences.iter().enumerate() { for (i, (sentence, count)) in sentences.iter().enumerate() {
let mut lattice = Lattice::from(sentence, 0, bos_id, eos_id); let mut lattice = Lattice::from(sentence, bos_id, eos_id);
model.populate_nodes(&mut lattice); model.populate_nodes(&mut lattice);
vsum += *count as f64; vsum += *count as f64;
for node_ref in lattice.viterbi() { for node_ref in lattice.viterbi() {
@@ -365,7 +391,7 @@ impl UnigramTrainer {
// TODO reparallelize this. // TODO reparallelize this.
for (string, freq) in sentences { for (string, freq) in sentences {
let mut lattice = Lattice::from(string, model.unk_id, model.bos_id, model.eos_id); let mut lattice = Lattice::from(string, model.bos_id, model.eos_id);
model.populate_nodes(&mut lattice); model.populate_nodes(&mut lattice);
let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected); let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected);
ntokens += lattice.viterbi().len() as u32; ntokens += lattice.viterbi().len() as u32;
@@ -422,13 +448,7 @@ impl UnigramTrainer {
self.update_progress(&progress, sentences.len(), "Suffix array seeds"); self.update_progress(&progress, sentences.len(), "Suffix array seeds");
let mut pieces: Vec<SentencePiece> = let mut pieces: Vec<SentencePiece> =
Vec::with_capacity(self.vocab_size.try_into().unwrap()); Vec::with_capacity(self.vocab_size.try_into().unwrap());
// XXX: Make sure unk exists and are ids 0
pieces.push((self.unk_token.clone(), f64::NAN));
pieces.extend(
self.special_tokens
.iter()
.map(|tok| (tok.content.clone(), f64::NAN)),
);
pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?); pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?);
self.finalize_progress(&progress, sentences.len()); self.finalize_progress(&progress, sentences.len());
@@ -452,7 +472,7 @@ impl UnigramTrainer {
let expected_updates = expected_loops as usize * self.n_sub_iterations as usize; let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
self.update_progress(&progress, expected_updates, "EM training"); self.update_progress(&progress, expected_updates, "EM training");
let required_chars = self.required_chars(&sentences); let required_chars = self.required_chars(&sentences);
let mut model = Unigram::from(pieces.clone(), 0)?; let mut model = Unigram::from(pieces.clone(), None)?;
loop { loop {
// Sub-EM iteration. // Sub-EM iteration.
for _iter in 0..self.n_sub_iterations { for _iter in 0..self.n_sub_iterations {
@@ -461,7 +481,7 @@ impl UnigramTrainer {
// Executes M step. // Executes M step.
pieces = self.run_m_step(&pieces, &expected); pieces = self.run_m_step(&pieces, &expected);
model = Unigram::from(pieces.clone(), 0)?; model = Unigram::from(pieces.clone(), None)?;
// Useful comment for checking compatibility with spm // Useful comment for checking compatibility with spm
debug!( debug!(
@@ -485,7 +505,7 @@ impl UnigramTrainer {
// Prunes pieces. // Prunes pieces.
pieces = self.prune_sentence_pieces(&model, &pieces, &sentences); pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
model = Unigram::from(pieces.clone(), 0)?; model = Unigram::from(pieces.clone(), None)?;
} }
self.finalize_progress(&progress, expected_updates); self.finalize_progress(&progress, expected_updates);
@@ -598,6 +618,72 @@ mod tests {
); );
} }
#[test]
fn test_unk_token() {
// 1. Should add `unk_token` as first special token
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.special_tokens(vec![
AddedToken::from("[SEP]", true),
AddedToken::from("[CLS]", true),
])
.unk_token(Some("[UNK]".into()))
.build()
.unwrap();
let (unigram, _) = trainer
.train(HashMap::from_iter(vec![
("The".into(), 12),
("are".into(), 11),
]))
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
// 2. Let it where it is
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.special_tokens(vec![
AddedToken::from("[SEP]", true),
AddedToken::from("[CLS]", true),
AddedToken::from("[UNK]", true),
])
.unk_token(Some("[UNK]".into()))
.build()
.unwrap();
let (unigram, _) = trainer
.train(HashMap::from_iter(vec![
("The".into(), 12),
("are".into(), 11),
]))
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
// 3. Don't put it there if not needed
let trainer = UnigramTrainerBuilder::default()
.show_progress(false)
.build()
.unwrap();
let (unigram, _) = trainer
.train(HashMap::from_iter(vec![
("The".into(), 12),
("are".into(), 11),
]))
.unwrap();
let mut pieces = unigram.iter();
assert_eq!(pieces.next().unwrap().0, "e".to_string());
}
#[test] #[test]
fn test_special_tokens() { fn test_special_tokens() {
let trainer = UnigramTrainerBuilder::default() let trainer = UnigramTrainerBuilder::default()
@@ -617,7 +703,6 @@ mod tests {
.unwrap(); .unwrap();
let mut pieces = unigram.iter(); let mut pieces = unigram.iter();
assert_eq!(pieces.next(), Some(&("<unk>".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0))); assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0))); assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
} }

View File

@@ -52,10 +52,11 @@ fn test_train_unigram_from_file() {
let trainer = UnigramTrainer::builder() let trainer = UnigramTrainer::builder()
.show_progress(false) .show_progress(false)
.unk_token(Some("<unk>".into()))
.build() .build()
.unwrap(); .unwrap();
let (model, _) = trainer.train(word_counts).unwrap(); let (model, _) = trainer.train(word_counts).unwrap();
assert_eq!(model.get_vocab_size(), 719); assert_eq!(model.get_vocab_size(), 717);
} }
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]