Fixing the vocab size of the trained Unigram model (#952)

* Fixing the vocab size of the trained Unigram model

* add test for the vocab size of the trained Unigram model

* Revert "add test for the vocab size of the trained Unigram model"

This reverts commit fb8955c831b357d1037548ceaa8789734d544646.

* Fixing the vocab size of the trained Unigram model

* format code

* move the vocab-size calculation out of the loop
commit 1bb9884f45 (parent daa4dd2288)
Author: Kaito Sugimoto
Date:   2022-03-19 02:13:17 +09:00 (committed by GitHub)

2 changed files with 42 additions and 12 deletions
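The core change, shown in the Rust hunk below, reserves room for the special tokens (and for an unk_token that is not already among them) before filling the vocabulary with learned pieces, so the trained model ends up with exactly the requested vocab_size. A minimal sketch of that budget arithmetic, in plain Python for illustration only (the variable names mirror the Rust diff and are not part of the tokenizers API):

# Illustration only: mirrors the budget computed in the Rust diff below.
vocab_size = 100
special_tokens = ["[PAD]", "[SEP]", "[CLS]"]   # "[UNK]" is not listed here
need_add_unk = True                            # unk_token="[UNK]" must still be appended

# Learned pieces the trainer may keep so that, once the special tokens
# (plus the extra "[UNK]") are inserted, the total is exactly vocab_size.
budget = vocab_size - len(special_tokens) - (1 if need_add_unk else 0)   # 100 - 3 - 1 = 96

final_vocab_size = budget + len(special_tokens) + (1 if need_add_unk else 0)
assert final_vocab_size == vocab_size == 100   # matches the added test below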

@@ -238,6 +238,28 @@ class TestUnigram:
                "[SEP]",
            ]

        tokenizer = Tokenizer(models.Unigram())
        trainer = trainers.UnigramTrainer(
            show_progress=False,
            special_tokens=["[PAD]", "[SEP]", "[CLS]"],
            unk_token="[UNK]",
            vocab_size=100,
        )
        tokenizer.train([filename], trainer=trainer)
        assert tokenizer.get_vocab_size() == 100

        tokenizer = Tokenizer(models.Unigram())
        trainer = trainers.UnigramTrainer(
            show_progress=False,
            special_tokens=["[PAD]", "[SEP]", "[CLS]", "[UNK]"],
            unk_token="[UNK]",
            vocab_size=100,
        )
        tokenizer.train([filename], trainer=trainer)
        assert tokenizer.get_vocab_size() == 100

    def test_cannot_train_different_model(self):
        tokenizer = Tokenizer(models.BPE())
        trainer = trainers.UnigramTrainer(show_progress=False)

@@ -126,19 +126,7 @@ impl UnigramTrainer {
                min_score_penalty += min_score_penalty_delta;
            }
        }
        for (token, score) in model.iter() {
            if inserted.contains::<str>(token) {
                continue;
            }
            inserted.insert(token.to_string());
            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
            if pieces.len() == self.vocab_size as usize {
                break;
            }
        }
        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
        // Insert the necessary tokens
        let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
            let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
                if t.content == *unk {

@@ -154,6 +142,26 @@ impl UnigramTrainer {
        } else {
            (None, false)
        };
        let vocab_size_without_special_tokens = if need_add_unk {
            self.vocab_size as usize - self.special_tokens.len() - 1
        } else {
            self.vocab_size as usize - self.special_tokens.len()
        };
        for (token, score) in model.iter() {
            if inserted.contains::<str>(token) {
                continue;
            }
            inserted.insert(token.to_string());
            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
            if pieces.len() == vocab_size_without_special_tokens {
                break;
            }
        }
        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
        // Insert the necessary tokens
        let mut special_tokens = self
            .special_tokens
            .iter()