Unigram - Add special_tokens at the end of training + optional unk
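In short, `Unigram::from` now takes an `Option<usize>` for `unk_id`, `Unigram::encode` is fallible, and the trainer appends the special tokens (and the unk token, when requested) at the end of training. A minimal sketch of the updated call sites, adapted from the doc-example changed below; the module paths (`tokenizers::models::unigram::Unigram`, `tokenizers::Result`) are assumed from the crate layout, not shown in this diff:

use tokenizers::models::unigram::Unigram;

fn unigram_demo() -> tokenizers::Result<()> {
    let pieces = vec![
        ("<unk>".to_string(), 0.0),
        ("a".to_string(), 0.1),
        ("b".to_string(), 0.2),
        ("c".to_string(), 0.3),
        ("cd".to_string(), 1.0),
        ("abc".to_string(), 5.0),
        ("abcd".to_string(), 10.0),
    ];
    // `unk_id` is now optional: `Some(0)` keeps the previous behaviour,
    // `None` builds a model without any unknown token.
    let model = Unigram::from(pieces, Some(0))?;
    // `encode` now returns a Result; it can report `MissingUnkId` when an
    // unknown token is needed but no `unk_id` was configured.
    let result = model.encode("abcdacdxx")?;
    assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
    Ok(())
}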
@@ -416,9 +416,8 @@ pub struct PyUnigram {}
 impl PyUnigram {
     #[new]
     fn new(vocab: Option<Vec<(String, f64)>>, unk_id: Option<usize>) -> PyResult<(Self, PyModel)> {
-        if vocab.is_some() && unk_id.is_none() || vocab.is_none() && unk_id.is_some() {}
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => {
+            (Some(vocab), unk_id) => {
                 let model = Unigram::from(vocab, unk_id).map_err(|e| {
                     exceptions::PyException::new_err(format!("Error while loading Unigram: {}", e))
                 })?;
@@ -42,3 +42,53 @@ class TestUnigram:
         trainer = trainers.BpeTrainer(special_tokens=["<unk>"], show_progress=False)
         bpe_tokenizer.train(trainer, [train_files["small"]])
 
+
+    def test_train_with_special_tokens(self):
+        filename = "tests/data/dummy-unigram-special_tokens-train.txt"
+        with open(filename, "w") as f:
+            f.write(
+                """
+[CLS] The Zen of Python, by Tim Peters [SEP]
+[CLS] Beautiful is better than ugly. [SEP]
+[CLS] Explicit is better than implicit. [SEP]
+[CLS] Simple is better than complex. [SEP]
+[CLS] Complex is better than complicated. [SEP]
+[CLS] Flat is better than nested. [SEP]
+[CLS] Sparse is better than dense. [SEP]
+[CLS] Readability counts. [SEP]
+[CLS] Special cases aren't special enough to break the rules. [SEP]
+[CLS] Although practicality beats purity. [SEP]
+[CLS] Errors should never pass silently. [SEP]
+[CLS] Unless explicitly silenced. [SEP]
+[CLS] In the face of ambiguity, refuse the temptation to guess. [SEP]
+[CLS] There should be one-- and preferably only one --obvious way to do it. [SEP]
+[CLS] Although that way may not be obvious at first unless you're Dutch. [SEP]
+[CLS] Now is better than never. [SEP]
+[CLS] Although never is often better than *right* now. [SEP]
+[CLS] If the implementation is hard to explain, it's a bad idea. [SEP]
+[CLS] If the implementation is easy to explain, it may be a good idea. [SEP]
+[CLS] Namespaces are one honking great idea -- let's do more of those! [SEP]
+"""
+            )
+
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False, special_tokens=["[PAD]", "[SEP]", "[CLS]"], unk_token="[UNK]"
+        )
+
+        tokenizer.train(trainer, [filename])
+
+        assert tokenizer.encode("[CLS] This is a test [SEP]").tokens == [
+            "[CLS]",
+            " T",
+            "h",
+            "i",
+            "s",
+            " is ",
+            "a",
+            " ",
+            "t",
+            "es",
+            "t ",
+            "[SEP]",
+        ]
@@ -58,7 +58,6 @@ pub struct Lattice<'a> {
     pub(super) end_nodes: Vec<Vec<NodeRef>>,
     bos_id: usize,
     eos_id: usize,
-    unk_id: usize,
 }
 
 impl std::fmt::Display for Lattice<'_> {
@@ -136,7 +135,7 @@ fn log_sum_exp(x: f64, y: f64, init_mode: bool) -> f64 {
 }
 
 impl<'a> Lattice<'a> {
-    pub fn from(sentence: &'a str, unk_id: usize, bos_id: usize, eos_id: usize) -> Lattice<'a> {
+    pub fn from(sentence: &'a str, bos_id: usize, eos_id: usize) -> Lattice<'a> {
         let len = sentence.bytes().count();
         let k_reserved_node_size = 16;
         // We are adding 2 tokens, bos and eos
@@ -161,7 +160,6 @@ impl<'a> Lattice<'a> {
             end_nodes,
             bos_id,
             eos_id,
-            unk_id,
         }
     }
 
@@ -439,16 +437,16 @@ mod tests {
 
     #[test]
     fn set_sentence() {
-        let lattice = Lattice::from("", 0, 1, 2);
+        let lattice = Lattice::from("", 1, 2);
 
         assert_eq!(lattice.len(), 0);
 
-        let lattice = Lattice::from("", 0, 1, 2);
+        let lattice = Lattice::from("", 1, 2);
         assert_eq!(lattice.len(), 0);
         assert_eq!(lattice.sentence(), "");
         assert_eq!(lattice.surface(0), "");
 
-        let lattice = Lattice::from("test", 0, 1, 2);
+        let lattice = Lattice::from("test", 1, 2);
         assert_eq!(lattice.len(), 4);
         assert_eq!(lattice.sentence(), "test");
         assert_eq!(lattice.surface(0), "test");
@@ -470,7 +468,7 @@ mod tests {
             eos.borrow().id
         );
 
-        let lattice = Lattice::from("テストab", 0, 1, 2);
+        let lattice = Lattice::from("テストab", 1, 2);
         assert_eq!(lattice.len(), 11);
         assert_eq!(lattice.sentence(), "テストab");
         assert_eq!(lattice.surface(0), "テストab");
@@ -482,7 +480,7 @@ mod tests {
 
     #[test]
     fn insert_test() {
-        let mut lattice = Lattice::from("ABあい", 0, 1, 2);
+        let mut lattice = Lattice::from("ABあい", 1, 2);
 
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
@@ -573,7 +571,7 @@ mod tests {
 
     #[test]
     fn test_viterbi() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
         assert_eq!(lattice.viterbi(), vec![]);
         // Still incomplete
         lattice.insert(0, 1, 0.0, 3);
@@ -586,7 +584,7 @@ mod tests {
 
     #[test]
     fn test_viterbi2() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
 
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
@@ -606,7 +604,7 @@ mod tests {
 
     #[test]
     fn test_nbest() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
         lattice.insert(0, 1, 0.0, 3);
         lattice.insert(1, 1, 0.0, 4);
         lattice.insert(2, 1, 0.0, 5);
@@ -641,7 +639,7 @@ mod tests {
 
     #[test]
     fn test_populate() {
-        let mut lattice = Lattice::from("ABC", 0, 1, 2);
+        let mut lattice = Lattice::from("ABC", 1, 2);
         lattice.insert(0, 1, 1.0, 3); // A
         lattice.insert(1, 1, 1.2, 4); // B
         lattice.insert(2, 1, 2.5, 5); // C
@@ -18,7 +18,7 @@ pub struct Unigram {
     cache: Cache<String, Vec<String>>,
     trie: Trie<u8>,
     pub min_score: f64,
-    pub(super) unk_id: usize,
+    pub(super) unk_id: Option<usize>,
     pub(super) bos_id: usize,
     pub(super) eos_id: usize,
 
@@ -54,7 +54,7 @@ impl Clone for Unigram {
 impl std::fmt::Debug for Unigram {
     fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         fmt.debug_struct("Unigram")
-            .field("vocab", &self.vocab)
+            .field("vocab", &self.vocab.len())
             .field("unk_id", &self.unk_id)
             .finish()
     }
@@ -66,6 +66,7 @@ static K_UNK_PENALTY: f64 = 10.0;
 pub enum UnigramError {
     EmptyVocabulary,
     UnkIdNotInVocabulary,
+    MissingUnkId,
 }
 
 impl std::fmt::Display for UnigramError {
@@ -77,6 +78,9 @@ impl std::fmt::Display for UnigramError {
             UnigramError::UnkIdNotInVocabulary => {
                 write!(f, "The `unk_id` is larger than vocabulary size")
             }
+            UnigramError::MissingUnkId => {
+                write!(f, "Encountered an unknown token but `unk_id` is missing")
+            }
         }
     }
 }
@@ -86,7 +90,7 @@ impl std::error::Error for UnigramError {}
 impl Default for Unigram {
     fn default() -> Self {
         let vocab = vec![("<unk>".to_string(), 0.0)];
-        Self::from(vocab, 0).unwrap()
+        Self::from(vocab, Some(0)).unwrap()
     }
 }
 
@@ -97,17 +101,19 @@ impl Unigram {
     /// unk_id, is the index within the vocabulary.
    /// For now `Unigram` *requires* at least `unk` because we might find a never seen char.
    /// Further versions might allow that part to be hidden.
-    pub fn from(vocab: Vec<(String, f64)>, unk_id: usize) -> Result<Self> {
+    pub fn from(vocab: Vec<(String, f64)>, unk_id: Option<usize>) -> Result<Self> {
         let n = vocab.len();
         let mut token_to_ids: TokenMap = HashMap::new();
         let mut builder = TrieBuilder::default();
 
+        if let Some(unk_id) = unk_id {
             if vocab.is_empty() {
                 return Err(Box::new(UnigramError::EmptyVocabulary));
             }
             if unk_id >= vocab.len() {
                 return Err(Box::new(UnigramError::UnkIdNotInVocabulary));
             }
+        }
 
         let bos_id = n + 1;
         let eos_id = n + 2;
@@ -187,7 +193,9 @@ impl Unigram {
             }
 
             if !has_single_node {
-                lattice.insert(begin_pos, mblen, unk_score, self.unk_id);
+                if let Some(unk_id) = self.unk_id {
+                    lattice.insert(begin_pos, mblen, unk_score, unk_id);
+                }
             }
             begin_pos += mblen
         }
@@ -209,28 +217,28 @@ impl Unigram {
     ///     ("abc".to_string(), 5.0),
     ///     ("abcd".to_string(), 10.0),
     /// ];
-    /// let model = Unigram::from(pieces, 0).unwrap();
-    /// let result = model.encode("abcdacdxx");
+    /// let model = Unigram::from(pieces, Some(0)).unwrap();
+    /// let result = model.encode("abcdacdxx").unwrap();
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
-    pub fn encode(&self, sentence: &str) -> Vec<String> {
+    pub fn encode(&self, sentence: &str) -> Result<Vec<String>> {
         if sentence.is_empty() {
-            return vec![];
+            return Ok(vec![]);
         }
         if let Some(result) = self.cache.get(sentence) {
-            result.to_vec()
+            Ok(result.to_vec())
         } else {
             let result = if self.is_optimized {
-                self.encode_optimized(sentence)
+                self.encode_optimized(sentence)?
             } else {
-                self.encode_unoptimized(sentence)
+                self.encode_unoptimized(sentence)?
             };
             self.cache.set(sentence.to_owned(), result.clone());
-            result
+            Ok(result)
         }
     }
 
-    fn encode_optimized(&self, sentence: &str) -> Vec<String> {
+    fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
         // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
         #[derive(Debug, Clone)]
         struct BestPathNode {
@@ -290,7 +298,7 @@ impl Unigram {
                 {
                     target_node.best_path_score = candidate_best_path_score;
                     target_node.starts_at = Some(starts_at);
-                    target_node.id = self.unk_id;
+                    target_node.id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
                 }
             }
             starts_at += mblen
@@ -301,16 +309,9 @@ impl Unigram {
         while ends_at > 0 {
             let node = &best_path_ends_at[ends_at];
             let starts_at = node.starts_at.unwrap();
-            if self.fuse_unk && node.id == self.unk_id {
+            if self.fuse_unk && node.id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
                 token.push(
-                    String::from_utf8(
-                        sentence
-                            .bytes()
-                            .skip(starts_at)
-                            .take(ends_at - starts_at)
-                            .collect(),
-                    )
-                    .unwrap(),
+                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
                 );
             } else {
                 if !token.is_empty() {
@@ -319,14 +320,7 @@ impl Unigram {
                     token = vec![];
                 }
                 results.push(
-                    String::from_utf8(
-                        sentence
-                            .bytes()
-                            .skip(starts_at)
-                            .take(ends_at - starts_at)
-                            .collect(),
-                    )
-                    .unwrap(),
+                    String::from_utf8(sentence[starts_at..ends_at].as_bytes().to_vec()).unwrap(),
                 );
             }
             ends_at = starts_at;
@@ -336,18 +330,18 @@ impl Unigram {
             results.push(token.concat());
         }
         results.reverse();
-        results
+        Ok(results)
     }
 
-    fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
-        let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
+    fn encode_unoptimized(&self, sentence: &str) -> Result<Vec<String>> {
+        let mut lattice = Lattice::from(sentence, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
         if self.fuse_unk {
             let mut results = vec![];
             let mut token = String::new();
             for node in lattice.viterbi().iter() {
                 let item = lattice.piece(&node.borrow());
-                if node.borrow().id == self.unk_id {
+                if node.borrow().id == self.unk_id.ok_or(UnigramError::MissingUnkId)? {
                     token.push_str(&item);
                 } else {
                     if !token.is_empty() {
@@ -360,9 +354,9 @@ impl Unigram {
             if !token.is_empty() {
                 results.push(token);
             }
-            results
+            Ok(results)
         } else {
-            lattice.tokens()
+            Ok(lattice.tokens())
         }
     }
 
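With `encode` now fallible, every caller has to decide how to consume the `Result`. A short sketch of hypothetical caller code (not part of this commit), assuming the public `tokenizers::models::unigram::Unigram` path:

use tokenizers::models::unigram::Unigram;

// Either propagate the error with `?`, or report it and pass it on.
fn encode_or_report(model: &Unigram, text: &str) -> tokenizers::Result<Vec<String>> {
    match model.encode(text) {
        Ok(pieces) => Ok(pieces),
        // `MissingUnkId` is returned when an unknown token would be needed
        // but the model was built with `unk_id = None`.
        Err(e) => {
            eprintln!("encode failed on {:?}: {}", text, e);
            Err(e)
        }
    }
}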
@@ -416,21 +410,20 @@ impl Model for Unigram {
     }
 
     fn tokenize(&self, sentence: &str) -> Result<Vec<Token>> {
-        let tokens = self.encode(sentence);
+        let str_tokens = self.encode(sentence)?;
         let mut offset = 0;
-        Ok(tokens
-            .iter()
-            .map(|string| {
-                let id: u32 = match self.token_to_ids.get(string) {
-                    Some(id) => *id,
-                    None => self.unk_id as u32,
-                };
-                let len = string.len();
-                let offsets = (offset, offset + len);
-                offset += len;
-                Token::new(id, string.to_string(), offsets)
-            })
-            .collect())
+        let mut tokens = Vec::with_capacity(str_tokens.len());
+        for string in str_tokens {
+            let id: u32 = match self.token_to_ids.get(&string) {
+                Some(id) => *id,
+                None => self.unk_id.ok_or(UnigramError::MissingUnkId)? as u32,
+            };
+            let len = string.len();
+            let offsets = (offset, offset + len);
+            offset += len;
+            tokens.push(Token::new(id, string, offsets));
+        }
+        Ok(tokens)
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
@@ -465,9 +458,9 @@ mod tests {
     #[test]
     fn test_populate_nodes_unk() {
         let pieces = vec![("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(pieces, 0).unwrap();
+        let model = Unigram::from(pieces, Some(0)).unwrap();
 
-        let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
+        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
 
         assert_eq!(lattice.begin_nodes[0].len(), 1);
@@ -490,9 +483,9 @@ mod tests {
             ("ab".to_string(), 0.3),
             ("bc".to_string(), 0.4),
         ];
-        let model = Unigram::from(pieces, 0).unwrap();
+        let model = Unigram::from(pieces, Some(0)).unwrap();
 
-        let mut lattice = Lattice::from("abc", 0, model.bos_id, model.eos_id);
+        let mut lattice = Lattice::from("abc", model.bos_id, model.eos_id);
         model.populate_nodes(&mut lattice);
 
         assert_eq!(lattice.begin_nodes[0].len(), 2); // a, ab
@@ -527,8 +520,8 @@ mod tests {
             ("abcd".to_string(), 10.0),
         ];
 
-        let model = Unigram::from(sentencepieces, 0).unwrap();
-        let result = model.encode("abcd");
+        let model = Unigram::from(sentencepieces, Some(0)).unwrap();
+        let result = model.encode("abcd").unwrap();
         assert_eq!(result, vec!["abcd"]);
     }
 
@@ -549,35 +542,41 @@ mod tests {
             ("qr".to_string(), -0.5),
         ];
 
-        let mut model = Unigram::from(sentencepieces, 0).unwrap();
+        let mut model = Unigram::from(sentencepieces, Some(0)).unwrap();
 
         for is_optimized in &[true, false] {
             model.set_optimized(*is_optimized);
             println!("IsOptimized {:?}", is_optimized);
-            assert_eq!(model.encode("abc"), vec!["abc"]);
-            assert_eq!(model.encode("AB"), vec!["AB"]);
+            assert_eq!(model.encode("abc").unwrap(), vec!["abc"]);
+            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
 
             model.set_fuse_unk(false);
-            assert_eq!(model.encode("AB"), vec!["A", "B"]);
+            assert_eq!(model.encode("AB").unwrap(), vec!["A", "B"]);
             model.set_fuse_unk(true);
-            assert_eq!(model.encode("AB"), vec!["AB"]);
+            assert_eq!(model.encode("AB").unwrap(), vec!["AB"]);
 
-            assert_eq!(model.encode("abcd"), vec!["ab", "cd"]);
-            assert_eq!(model.encode("abcc"), vec!["abc", "c"]);
+            assert_eq!(model.encode("abcd").unwrap(), vec!["ab", "cd"]);
+            assert_eq!(model.encode("abcc").unwrap(), vec!["abc", "c"]);
             assert_eq!(
-                model.encode("xabcabaabcdd"),
+                model.encode("xabcabaabcdd").unwrap(),
                 vec!["x", "abc", "ab", "a", "ab", "cd", "d"]
             );
             model.set_fuse_unk(false);
-            assert_eq!(model.encode("xyz東京"), vec!["x", "y", "z", "東", "京"]);
+            assert_eq!(
+                model.encode("xyz東京").unwrap(),
+                vec!["x", "y", "z", "東", "京"]
+            );
             model.set_fuse_unk(true);
-            assert_eq!(model.encode("xyz東京"), vec!["xyz東京"]);
+            assert_eq!(model.encode("xyz東京").unwrap(), vec!["xyz東京"]);
 
             // User encoded in original version
-            assert_eq!(model.encode("ABC"), vec!["ABC"]);
-            assert_eq!(model.encode("abABCcd"), vec!["ab", "ABC", "cd"]);
-            assert_eq!(model.encode("ababcdabcdcd"), vec!["ab", "abcdabcd", "cd"]);
-            assert_eq!(model.encode("abqrcd"), vec!["ab", "q", "r", "cd"]);
+            assert_eq!(model.encode("ABC").unwrap(), vec!["ABC"]);
+            assert_eq!(model.encode("abABCcd").unwrap(), vec!["ab", "ABC", "cd"]);
+            assert_eq!(
+                model.encode("ababcdabcdcd").unwrap(),
+                vec!["ab", "abcdabcd", "cd"]
+            );
+            assert_eq!(model.encode("abqrcd").unwrap(), vec!["ab", "q", "r", "cd"]);
         }
     }
 }
@@ -52,11 +52,9 @@ impl<'de> Visitor<'de> for UnigramVisitor {
             }
         }
         match (vocab, unk_id) {
-            (Some(vocab), Some(unk_id)) => Ok(Unigram::from(vocab, unk_id)
+            (Some(vocab), unk_id) => Ok(Unigram::from(vocab, unk_id)
                 .map_err(|err| Error::custom(&format!("Unable to load vocab {:?}", err)))?),
-            (None, Some(_)) => Err(Error::custom("Missing vocab")),
-            (None, None) => Err(Error::custom("Missing vocab and unk_id")),
-            (Some(_), None) => Err(Error::custom("Missing unk_id")),
+            (None, _) => Err(Error::custom("Missing vocab")),
         }
     }
 }
@@ -68,7 +66,7 @@ mod test {
     #[test]
     fn test_serialization() {
         let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
-        let model = Unigram::from(vocab, 0).unwrap();
+        let model = Unigram::from(vocab, Some(0)).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -79,7 +77,18 @@ mod test {
     #[test]
     fn test_serialization_unk_id_not_zero() {
         let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
-        let model = Unigram::from(vocab, 1).unwrap();
+        let model = Unigram::from(vocab, Some(1)).unwrap();
+
+        let data = serde_json::to_string(&model).unwrap();
+        let reconstructed = serde_json::from_str(&data).unwrap();
+
+        assert_eq!(model, reconstructed);
+    }
+
+    #[test]
+    fn test_serialization_no_unk_id() {
+        let vocab = vec![("a".to_string(), -0.5)];
+        let model = Unigram::from(vocab, None).unwrap();
 
         let data = serde_json::to_string(&model).unwrap();
         let reconstructed = serde_json::from_str(&data).unwrap();
@@ -51,8 +51,8 @@ pub struct UnigramTrainer {
     #[builder(default = "HashSet::new()")]
     initial_alphabet: HashSet<char>,
 
-    #[builder(default = "String::from(\"<unk>\")")]
-    unk_token: String,
+    #[builder(default = "None")]
+    unk_token: Option<String>,
 
     #[builder(default = "16")]
     max_piece_length: usize,
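Since `unk_token` no longer defaults to `"<unk>"`, training code that needs an unknown token has to opt in explicitly. A hedged sketch mirroring the updated integration test at the end of this diff; the module path and builder re-export are assumed, not shown here:

use tokenizers::models::unigram::UnigramTrainer;

// Sketch only: builds a trainer that keeps an unknown token.
fn build_trainer() -> UnigramTrainer {
    UnigramTrainer::builder()
        .show_progress(false)
        // `unk_token` now defaults to `None`; set it explicitly to keep an
        // unknown token (it becomes the first special token if not listed).
        .unk_token(Some("<unk>".into()))
        .build()
        .unwrap()
}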
@@ -122,7 +122,33 @@ impl UnigramTrainer {
             }
         }
         pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
-        Unigram::from(pieces, 0)
+
+        // Insert the necessary tokens
+        let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
+            let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
+                if t.content == *unk {
+                    Some(i)
+                } else {
+                    None
+                }
+            });
+            match unk_id {
+                Some(id) => (Some(id), false),
+                None => (Some(0), true),
+            }
+        } else {
+            (None, false)
+        };
+        let mut special_tokens = self
+            .special_tokens
+            .iter()
+            .map(|t| (t.content.clone(), 0.0))
+            .collect::<Vec<_>>();
+        if need_add_unk {
+            special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
+        }
+
+        Unigram::from(special_tokens.into_iter().chain(pieces).collect(), unk_id)
     }
 
     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
@@ -230,7 +256,7 @@ impl UnigramTrainer {
                 always_keep[id] = false;
                 continue;
             }
-            let mut lattice = Lattice::from(token, 0, bos_id, eos_id);
+            let mut lattice = Lattice::from(token, bos_id, eos_id);
             model.populate_nodes(&mut lattice);
 
             let nbests = lattice.nbest(2);
@@ -255,7 +281,7 @@ impl UnigramTrainer {
         let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
         // TODO reparallelize this
         for (i, (sentence, count)) in sentences.iter().enumerate() {
-            let mut lattice = Lattice::from(sentence, 0, bos_id, eos_id);
+            let mut lattice = Lattice::from(sentence, bos_id, eos_id);
             model.populate_nodes(&mut lattice);
             vsum += *count as f64;
             for node_ref in lattice.viterbi() {
@@ -365,7 +391,7 @@ impl UnigramTrainer {
 
         // TODO reparallelize this.
         for (string, freq) in sentences {
-            let mut lattice = Lattice::from(string, model.unk_id, model.bos_id, model.eos_id);
+            let mut lattice = Lattice::from(string, model.bos_id, model.eos_id);
             model.populate_nodes(&mut lattice);
             let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected);
             ntokens += lattice.viterbi().len() as u32;
@@ -422,13 +448,7 @@ impl UnigramTrainer {
         self.update_progress(&progress, sentences.len(), "Suffix array seeds");
         let mut pieces: Vec<SentencePiece> =
             Vec::with_capacity(self.vocab_size.try_into().unwrap());
-        // XXX: Make sure unk exists and are ids 0
-        pieces.push((self.unk_token.clone(), f64::NAN));
-        pieces.extend(
-            self.special_tokens
-                .iter()
-                .map(|tok| (tok.content.clone(), f64::NAN)),
-        );
         pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?);
         self.finalize_progress(&progress, sentences.len());
 
@@ -452,7 +472,7 @@ impl UnigramTrainer {
         let expected_updates = expected_loops as usize * self.n_sub_iterations as usize;
         self.update_progress(&progress, expected_updates, "EM training");
         let required_chars = self.required_chars(&sentences);
-        let mut model = Unigram::from(pieces.clone(), 0)?;
+        let mut model = Unigram::from(pieces.clone(), None)?;
         loop {
             // Sub-EM iteration.
             for _iter in 0..self.n_sub_iterations {
@@ -461,7 +481,7 @@ mod tests {
 
                 // Executes M step.
                 pieces = self.run_m_step(&pieces, &expected);
-                model = Unigram::from(pieces.clone(), 0)?;
+                model = Unigram::from(pieces.clone(), None)?;
 
                 // Useful comment for checking compatibility with spm
                 debug!(
@@ -485,7 +505,7 @@ mod tests {
 
             // Prunes pieces.
             pieces = self.prune_sentence_pieces(&model, &pieces, &sentences);
-            model = Unigram::from(pieces.clone(), 0)?;
+            model = Unigram::from(pieces.clone(), None)?;
         }
         self.finalize_progress(&progress, expected_updates);
 
@@ -598,6 +618,72 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_unk_token() {
+        // 1. Should add `unk_token` as first special token
+        let trainer = UnigramTrainerBuilder::default()
+            .show_progress(false)
+            .special_tokens(vec![
+                AddedToken::from("[SEP]", true),
+                AddedToken::from("[CLS]", true),
+            ])
+            .unk_token(Some("[UNK]".into()))
+            .build()
+            .unwrap();
+
+        let (unigram, _) = trainer
+            .train(HashMap::from_iter(vec![
+                ("The".into(), 12),
+                ("are".into(), 11),
+            ]))
+            .unwrap();
+
+        let mut pieces = unigram.iter();
+        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
+
+        // 2. Let it where it is
+        let trainer = UnigramTrainerBuilder::default()
+            .show_progress(false)
+            .special_tokens(vec![
+                AddedToken::from("[SEP]", true),
+                AddedToken::from("[CLS]", true),
+                AddedToken::from("[UNK]", true),
+            ])
+            .unk_token(Some("[UNK]".into()))
+            .build()
+            .unwrap();
+
+        let (unigram, _) = trainer
+            .train(HashMap::from_iter(vec![
+                ("The".into(), 12),
+                ("are".into(), 11),
+            ]))
+            .unwrap();
+
+        let mut pieces = unigram.iter();
+        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
+        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
+
+        // 3. Don't put it there if not needed
+        let trainer = UnigramTrainerBuilder::default()
+            .show_progress(false)
+            .build()
+            .unwrap();
+
+        let (unigram, _) = trainer
+            .train(HashMap::from_iter(vec![
+                ("The".into(), 12),
+                ("are".into(), 11),
+            ]))
+            .unwrap();
+
+        let mut pieces = unigram.iter();
+        assert_eq!(pieces.next().unwrap().0, "e".to_string());
+    }
+
     #[test]
     fn test_special_tokens() {
         let trainer = UnigramTrainerBuilder::default()
@@ -617,7 +703,6 @@ mod tests {
             .unwrap();
 
         let mut pieces = unigram.iter();
-        assert_eq!(pieces.next(), Some(&("<unk>".into(), 0.0)));
         assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
         assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
     }
@@ -52,10 +52,11 @@ fn test_train_unigram_from_file() {
 
     let trainer = UnigramTrainer::builder()
         .show_progress(false)
+        .unk_token(Some("<unk>".into()))
         .build()
         .unwrap();
     let (model, _) = trainer.train(word_counts).unwrap();
-    assert_eq!(model.get_vocab_size(), 719);
+    assert_eq!(model.get_vocab_size(), 717);
 }
 
 #[cfg(not(debug_assertions))]