mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 14:59:20 +00:00
Allow initial_alphabet on UnigramTrainer
This commit is contained in:
@ -46,7 +46,7 @@ class BpeTrainer(Trainer):
|
||||
initial_alphabet: List[str]:
|
||||
A list of characters to include in the initial alphabet, even
|
||||
if not seen in the training dataset.
|
||||
If the strings contains more than one character, only the first one
|
||||
If the strings contain more than one character, only the first one
|
||||
is kept.
|
||||
|
||||
continuing_subword_prefix: Optional[str]:
|
||||
@ -98,7 +98,7 @@ class WordPieceTrainer(Trainer):
|
||||
initial_alphabet: List[str]:
|
||||
A list of characters to include in the initial alphabet, even
|
||||
if not seen in the training dataset.
|
||||
If the strings contains more than one character, only the first one
|
||||
If the strings contain more than one character, only the first one
|
||||
is kept.
|
||||
|
||||
continuing_subword_prefix: Optional[str]:
|
||||
@ -136,6 +136,12 @@ class UnigramTrainer(Trainer):
|
||||
special_tokens: List[Union[str, AddedToken]]:
|
||||
A list of special tokens the model should know of.
|
||||
|
||||
initial_alphabet: List[str]:
|
||||
A list of characters to include in the initial alphabet, even
|
||||
if not seen in the training dataset.
|
||||
If the strings contain more than one character, only the first one
|
||||
is kept.
|
||||
|
||||
Returns:
|
||||
Trainer
|
||||
"""
|
||||
|
@ -193,6 +193,17 @@ impl PyUnigramTrainer {
|
||||
"unk_token" => builder.unk_token(val.extract()?),
|
||||
"max_piece_length" => builder.max_piece_length(val.extract()?),
|
||||
"seed_size" => builder.seed_size(val.extract()?),
|
||||
"initial_alphabet" => {
|
||||
let alphabet: Vec<String> = val.extract()?;
|
||||
builder.initial_alphabet(
|
||||
alphabet
|
||||
.into_iter()
|
||||
.map(|s| s.chars().next())
|
||||
.filter(|c| c.is_some())
|
||||
.map(|c| c.unwrap())
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
"special_tokens" => builder.special_tokens(
|
||||
val.cast_as::<PyList>()?
|
||||
.into_iter()
|
||||
|
@ -48,6 +48,8 @@ pub struct UnigramTrainer {
|
||||
shrinking_factor: f64,
|
||||
#[builder(default = "vec![]")]
|
||||
special_tokens: Vec<AddedToken>,
|
||||
#[builder(default = "HashSet::new()")]
|
||||
initial_alphabet: HashSet<char>,
|
||||
|
||||
#[builder(default = "String::from(\"<unk>\")")]
|
||||
unk_token: String,
|
||||
@ -125,8 +127,8 @@ impl UnigramTrainer {
|
||||
fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
|
||||
word_counts
|
||||
.iter()
|
||||
.map(|(s, _count)| s.chars())
|
||||
.flatten()
|
||||
.flat_map(|(s, _count)| s.chars())
|
||||
.chain(self.initial_alphabet.iter().copied())
|
||||
.map(|c| c.to_string())
|
||||
.collect()
|
||||
}
|
||||
@ -517,6 +519,7 @@ impl Trainer for UnigramTrainer {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_approx_eq::assert_approx_eq;
|
||||
use std::iter::FromIterator;
|
||||
|
||||
#[test]
|
||||
fn test_unigram_chars() {
|
||||
@ -569,6 +572,26 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initial_alphabet() {
|
||||
let trainer = UnigramTrainerBuilder::default()
|
||||
.show_progress(false)
|
||||
.initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let sentences = vec![("こんにちは友達".to_string(), 1)];
|
||||
let required_chars = trainer.required_chars(&sentences);
|
||||
assert_eq!(
|
||||
required_chars,
|
||||
HashSet::from_iter(
|
||||
vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
|
||||
.into_iter()
|
||||
.map(|s| s.to_owned())
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_log_prob() {
|
||||
let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)];
|
||||
|
Reference in New Issue
Block a user