simplify initialization of BpeTrainer

This commit is contained in:
epwalsh
2019-12-23 14:38:54 -08:00
committed by MOI Anthony
parent fab1d4cabc
commit c0ed873c4d
2 changed files with 22 additions and 39 deletions

View File

@ -20,18 +20,18 @@ impl BpeTrainer {
#[staticmethod]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
let mut config: tk::models::bpe::BpeTrainerConfig = Default::default();
let mut trainer = tk::models::bpe::BpeTrainer::default();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"vocab_size" => {
let size: usize = val.extract()?;
config.set_vocab_size(size);
trainer.vocab_size = size;
}
"min_frequency" => {
let freq: u32 = val.extract()?;
config.set_min_frequency(freq);
trainer.min_frequency = freq;
}
_ => println!("Ignored unknown kwargs option {}", key),
};
@ -39,7 +39,7 @@ impl BpeTrainer {
}
Ok(Trainer {
trainer: Container::Owned(Box::new(tk::models::bpe::BpeTrainer::new(config))),
trainer: Container::Owned(Box::new(trainer)),
})
}
}

View File

@ -7,34 +7,6 @@ use std::{
time::Instant,
};
pub struct BpeTrainerConfig {
min_frequency: u32,
vocab_size: usize,
}
impl BpeTrainerConfig {
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
BpeTrainerConfig {
min_frequency,
vocab_size,
}
}
pub fn set_vocab_size(&mut self, value: usize) {
self.vocab_size = value;
}
pub fn set_min_frequency(&mut self, value: u32) {
self.min_frequency = value;
}
}
impl Default for BpeTrainerConfig {
fn default() -> Self {
BpeTrainerConfig::new(0, 30000)
}
}
/// In charge of training a BPE model from a mapping of words to word counts.
///
/// # Examples
@ -51,15 +23,26 @@ impl Default for BpeTrainerConfig {
/// let trainer = BpeTrainer::default();
/// let model = trainer.train(word_counts);
/// ```
#[derive(Default)]
pub struct BpeTrainer {
// Training parameters
config: BpeTrainerConfig,
pub min_frequency: u32,
pub vocab_size: usize,
}
impl Default for BpeTrainer {
fn default() -> Self {
Self {
min_frequency: 0,
vocab_size: 30000,
}
}
}
impl BpeTrainer {
pub fn new(config: BpeTrainerConfig) -> Self {
BpeTrainer { config }
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
Self {
min_frequency,
vocab_size,
}
}
}
@ -134,7 +117,7 @@ impl Trainer for BpeTrainer {
let timer = Instant::now();
loop {
// Stop as soon as we have a big enough vocabulary
if word_to_id.len() >= self.config.vocab_size {
if word_to_id.len() >= self.vocab_size {
break;
}
@ -150,7 +133,7 @@ impl Trainer for BpeTrainer {
}
}
// Stop if we reached the minimum frequency
if best_count < 1 || self.config.min_frequency > best_count as u32 {
if best_count < 1 || self.min_frequency > best_count as u32 {
break;
}