Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-23 00:35:35 +00:00
Fixing the vocab size of the trained Unigram model (#952)
* Fixing the vocab size of the trained Unigram model
* add test for the vocab size of the trained Unigram model
* Revert "add test for the vocab size of the trained Unigram model"
  This reverts commit fb8955c831b357d1037548ceaa8789734d544646.
* Fixing the vocab size of the trained Unigram model
* format codes
* get the position of vocab-size calculation out of loop
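For reference, this is the behavior the fix establishes: the requested vocab_size must already account for the special tokens (and, when needed, the unk token), so the trained vocabulary comes out to exactly that size. A minimal sketch, adapted from the test added below; the corpus path "corpus.txt" is a placeholder, not part of the change:

from tokenizers import Tokenizer, models, trainers

tokenizer = Tokenizer(models.Unigram())
trainer = trainers.UnigramTrainer(
    show_progress=False,
    special_tokens=["[PAD]", "[SEP]", "[CLS]"],
    unk_token="[UNK]",
    vocab_size=100,
)
tokenizer.train(["corpus.txt"], trainer=trainer)  # "corpus.txt" is a placeholder training file

# Previously the learned pieces alone could fill vocab_size and the special
# tokens were then added on top of them; with this fix the total stays at 100.
assert tokenizer.get_vocab_size() == 100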
@@ -238,6 +238,28 @@ class TestUnigram:
             "[SEP]",
         ]
 
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False,
+            special_tokens=["[PAD]", "[SEP]", "[CLS]"],
+            unk_token="[UNK]",
+            vocab_size=100,
+        )
+        tokenizer.train([filename], trainer=trainer)
+
+        assert tokenizer.get_vocab_size() == 100
+
+        tokenizer = Tokenizer(models.Unigram())
+        trainer = trainers.UnigramTrainer(
+            show_progress=False,
+            special_tokens=["[PAD]", "[SEP]", "[CLS]", "[UNK]"],
+            unk_token="[UNK]",
+            vocab_size=100,
+        )
+        tokenizer.train([filename], trainer=trainer)
+
+        assert tokenizer.get_vocab_size() == 100
+
     def test_cannot_train_different_model(self):
         tokenizer = Tokenizer(models.BPE())
         trainer = trainers.UnigramTrainer(show_progress=False)
@@ -126,19 +126,7 @@ impl UnigramTrainer {
                 min_score_penalty += min_score_penalty_delta;
             }
         }
-        for (token, score) in model.iter() {
-            if inserted.contains::<str>(token) {
-                continue;
-            }
-            inserted.insert(token.to_string());
-            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
-            if pieces.len() == self.vocab_size as usize {
-                break;
-            }
-        }
-        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
-
         // Insert the necessary tokens
         let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
             let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
                 if t.content == *unk {
@@ -154,6 +142,26 @@ impl UnigramTrainer {
         } else {
             (None, false)
         };
+
+        let vocab_size_without_special_tokens = if need_add_unk {
+            self.vocab_size as usize - self.special_tokens.len() - 1
+        } else {
+            self.vocab_size as usize - self.special_tokens.len()
+        };
+        for (token, score) in model.iter() {
+            if inserted.contains::<str>(token) {
+                continue;
+            }
+            inserted.insert(token.to_string());
+            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
+
+            if pieces.len() == vocab_size_without_special_tokens {
+                break;
+            }
+        }
+        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
+
+        // Insert the necessary tokens
         let mut special_tokens = self
             .special_tokens
             .iter()
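In other words, the number of learned pieces kept is now the requested size minus the slots taken by the special tokens, with one extra slot reserved when the unk token still has to be added. A rough Python restatement of that calculation, for illustration only (the function name is made up):

def piece_budget(vocab_size: int, num_special_tokens: int, need_add_unk: bool) -> int:
    # Mirrors vocab_size_without_special_tokens in the Rust change above.
    budget = vocab_size - num_special_tokens
    if need_add_unk:
        budget -= 1  # reserve a slot for the unk token added later
    return budget

# vocab_size=100 with ["[PAD]", "[SEP]", "[CLS]"] and a separate "[UNK]" token:
# 100 - 3 - 1 = 96 learned pieces, so the final vocabulary is exactly 100.
assert piece_budget(100, 3, need_add_unk=True) == 96
# If "[UNK]" is already listed among the special tokens, nothing extra is reserved:
assert piece_budget(100, 4, need_add_unk=False) == 96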