Allow initial_alphabet on UnigramTrainer

This commit is contained in:
Anthony MOI
2020-10-22 14:32:40 -04:00
committed by Anthony MOI
parent f7c61c267a
commit 1a6f4b5204
3 changed files with 44 additions and 4 deletions

View File

@ -46,7 +46,7 @@ class BpeTrainer(Trainer):
initial_alphabet: List[str]:
A list of characters to include in the initial alphabet, even
if not seen in the training dataset.
If the strings contains more than one character, only the first one
If the strings contain more than one character, only the first one
is kept.
continuing_subword_prefix: Optional[str]:
@ -98,7 +98,7 @@ class WordPieceTrainer(Trainer):
initial_alphabet: List[str]:
A list of characters to include in the initial alphabet, even
if not seen in the training dataset.
If the strings contains more than one character, only the first one
If the strings contain more than one character, only the first one
is kept.
continuing_subword_prefix: Optional[str]:
@ -136,6 +136,12 @@ class UnigramTrainer(Trainer):
special_tokens: List[Union[str, AddedToken]]:
A list of special tokens the model should know of.
initial_alphabet: List[str]:
A list of characters to include in the initial alphabet, even
if not seen in the training dataset.
If the strings contain more than one character, only the first one
is kept.
Returns:
Trainer
"""

View File

@ -193,6 +193,17 @@ impl PyUnigramTrainer {
"unk_token" => builder.unk_token(val.extract()?),
"max_piece_length" => builder.max_piece_length(val.extract()?),
"seed_size" => builder.seed_size(val.extract()?),
"initial_alphabet" => {
let alphabet: Vec<String> = val.extract()?;
builder.initial_alphabet(
alphabet
.into_iter()
.map(|s| s.chars().next())
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect(),
)
}
"special_tokens" => builder.special_tokens(
val.cast_as::<PyList>()?
.into_iter()

View File

@ -48,6 +48,8 @@ pub struct UnigramTrainer {
shrinking_factor: f64,
#[builder(default = "vec![]")]
special_tokens: Vec<AddedToken>,
#[builder(default = "HashSet::new()")]
initial_alphabet: HashSet<char>,
#[builder(default = "String::from(\"<unk>\")")]
unk_token: String,
@ -125,8 +127,8 @@ impl UnigramTrainer {
/// Returns the set of characters the trainer has to keep, each stored as a
/// one-character `String`: every character of every sentence in
/// `word_counts`, plus every entry of `self.initial_alphabet`.
/// The per-sentence count is ignored.
// NOTE(review): this listing reads like a rendered diff hunk — the
// `.map(..)` / `.flatten()` pair and the `.flat_map(..)` / `.chain(..)` pair
// look like the before and after of the same step; confirm against the
// repository which pair is live before relying on this text.
fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
word_counts
.iter()
.map(|(s, _count)| s.chars())
.flatten()
.flat_map(|(s, _count)| s.chars())
// Appends the configured alphabet (copied `char`s) to the characters yielded above.
.chain(self.initial_alphabet.iter().copied())
.map(|c| c.to_string())
.collect()
}
@ -517,6 +519,7 @@ impl Trainer for UnigramTrainer {
mod tests {
use super::*;
use assert_approx_eq::assert_approx_eq;
use std::iter::FromIterator;
#[test]
fn test_unigram_chars() {
@ -569,6 +572,26 @@ mod tests {
}
}
#[test]
fn test_initial_alphabet() {
    // The configured alphabet uses six ASCII letters that never occur in the
    // training sentence, so their presence in the result can only come from
    // `initial_alphabet`.
    let trainer = UnigramTrainerBuilder::default()
        .show_progress(false)
        .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
        .build()
        .unwrap();

    let sentences = vec![("こんにちは友達".to_string(), 1)];
    let required_chars = trainer.required_chars(&sentences);

    // Expected: the seven distinct characters of the sentence followed by the
    // six configured letters, each as a one-character `String`.
    assert_eq!(
        required_chars,
        HashSet::from_iter(
            vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(|s| s.to_owned())
        )
    );
}
#[test]
fn test_to_log_prob() {
let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)];