mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
simplify initialization of BpeTrainer
This commit is contained in:
@ -20,18 +20,18 @@ impl BpeTrainer {
|
||||
#[staticmethod]
|
||||
#[args(kwargs = "**")]
|
||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<Trainer> {
|
||||
let mut config: tk::models::bpe::BpeTrainerConfig = Default::default();
|
||||
let mut trainer = tk::models::bpe::BpeTrainer::default();
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, val) in kwargs {
|
||||
let key: &str = key.extract()?;
|
||||
match key {
|
||||
"vocab_size" => {
|
||||
let size: usize = val.extract()?;
|
||||
config.set_vocab_size(size);
|
||||
trainer.vocab_size = size;
|
||||
}
|
||||
"min_frequency" => {
|
||||
let freq: u32 = val.extract()?;
|
||||
config.set_min_frequency(freq);
|
||||
trainer.min_frequency = freq;
|
||||
}
|
||||
_ => println!("Ignored unknown kwargs option {}", key),
|
||||
};
|
||||
@ -39,7 +39,7 @@ impl BpeTrainer {
|
||||
}
|
||||
|
||||
Ok(Trainer {
|
||||
trainer: Container::Owned(Box::new(tk::models::bpe::BpeTrainer::new(config))),
|
||||
trainer: Container::Owned(Box::new(trainer)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -7,34 +7,6 @@ use std::{
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
/// Training parameters consumed by `BpeTrainer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BpeTrainerConfig {
    /// Frequency threshold: the training loop stops once the best pair count
    /// is lower than this value.
    min_frequency: u32,
    /// Size threshold: the training loop stops once the vocabulary holds at
    /// least this many entries.
    vocab_size: usize,
}

impl BpeTrainerConfig {
    /// Value of `min_frequency` produced by [`Default`].
    const DEFAULT_MIN_FREQUENCY: u32 = 0;
    /// Value of `vocab_size` produced by [`Default`].
    const DEFAULT_VOCAB_SIZE: usize = 30_000;

    /// Builds a configuration from explicit values.
    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
        Self {
            min_frequency,
            vocab_size,
        }
    }

    /// Replaces the vocabulary size threshold.
    pub fn set_vocab_size(&mut self, value: usize) {
        self.vocab_size = value;
    }

    /// Replaces the minimum frequency threshold.
    pub fn set_min_frequency(&mut self, value: u32) {
        self.min_frequency = value;
    }
}

impl Default for BpeTrainerConfig {
    /// Returns a configuration with `min_frequency = 0` and `vocab_size = 30000`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MIN_FREQUENCY, Self::DEFAULT_VOCAB_SIZE)
    }
}
|
||||
|
||||
/// In charge of training a BPE model from a mapping of words to word counts.
|
||||
///
|
||||
/// # Examples
|
||||
@ -51,15 +23,26 @@ impl Default for BpeTrainerConfig {
|
||||
/// let trainer = BpeTrainer::default();
|
||||
/// let model = trainer.train(word_counts);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
pub struct BpeTrainer {
    // Training parameters
    /// Frequency threshold: the training loop stops once the best pair count
    /// is lower than this value.
    pub min_frequency: u32,
    /// Size threshold: the training loop stops once the vocabulary holds at
    /// least this many entries.
    pub vocab_size: usize,
}

impl Default for BpeTrainer {
    /// Returns a trainer with `min_frequency = 0` and `vocab_size = 30000`.
    fn default() -> Self {
        Self {
            min_frequency: 0,
            vocab_size: 30000,
        }
    }
}

impl BpeTrainer {
    /// Creates a trainer with explicit training parameters.
    pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
        Self {
            min_frequency,
            vocab_size,
        }
    }
}
|
||||
|
||||
@ -134,7 +117,7 @@ impl Trainer for BpeTrainer {
|
||||
let timer = Instant::now();
|
||||
loop {
|
||||
// Stop as soon as we have a big enough vocabulary
|
||||
if word_to_id.len() >= self.config.vocab_size {
|
||||
if word_to_id.len() >= self.vocab_size {
|
||||
break;
|
||||
}
|
||||
|
||||
@ -150,7 +133,7 @@ impl Trainer for BpeTrainer {
|
||||
}
|
||||
}
|
||||
// Stop if we reached the minimum frequency
|
||||
if best_count < 1 || self.config.min_frequency > best_count as u32 {
|
||||
if best_count < 1 || self.min_frequency > best_count as u32 {
|
||||
break;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user