Add BpeTrainerBuilder

This commit is contained in:
Anthony MOI
2020-01-03 16:09:47 -05:00
parent 02a89bb07f
commit 1bfe9fd0a7

View File

@ -5,6 +5,104 @@ use crate::tokenizer::{Model, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle};
use std::collections::{HashMap, HashSet};
#[derive(Default)]
pub struct Config {
min_frequency: Option<u32>,
vocab_size: Option<usize>,
show_progress: Option<bool>,
special_tokens: Option<Vec<String>>,
limit_alphabet: Option<usize>,
initial_alphabet: Option<HashSet<char>>,
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
}
#[derive(Default)]
pub struct BpeTrainerBuilder {
config: Config,
}
impl BpeTrainerBuilder {
/// Constructs a new `BpeTrainerBuilder`
pub fn new() -> Self {
Self::default()
}
/// Set the expected minimum frequency
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.config.min_frequency = Some(frequency);
self
}
/// Set the vocabulary size
pub fn vocab_size(mut self, size: usize) -> Self {
self.config.vocab_size = Some(size);
self
}
/// Set whether to show progress
pub fn show_progress(mut self, show: bool) -> Self {
self.config.show_progress = Some(show);
self
}
/// Set the special tokens
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
self.config.special_tokens = Some(tokens);
self
}
/// Set whether to limit the alphabet
pub fn limit_alphabet(mut self, limit: usize) -> Self {
self.config.limit_alphabet = Some(limit);
self
}
/// Set the initial alphabet
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.config.initial_alphabet = Some(alphabet);
self
}
/// Set the continuing_subword_prefix
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.config.continuing_subword_prefix = Some(prefix);
self
}
/// Set the end_of_word_suffix
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
self.config.end_of_word_suffix = Some(suffix);
self
}
/// Constructs the final BpeTrainer
pub fn build(self) -> Result<BpeTrainer> {
let mut trainer = BpeTrainer::default();
if let Some(freq) = self.config.min_frequency {
trainer.min_frequency = freq;
}
if let Some(vocab_size) = self.config.vocab_size {
trainer.vocab_size = vocab_size;
}
if let Some(show) = self.config.show_progress {
trainer.show_progress = show;
}
if let Some(special_tokens) = self.config.special_tokens {
trainer.special_tokens = special_tokens;
}
if let Some(alphabet) = self.config.initial_alphabet {
trainer.initial_alphabet = alphabet;
}
trainer.limit_alphabet = self.config.limit_alphabet;
trainer.continuing_subword_prefix = self.config.continuing_subword_prefix;
trainer.end_of_word_suffix = self.config.end_of_word_suffix;
Ok(trainer)
}
}
/// In charge of training a BPE model from a mapping of words to word counts.
///
/// # Examples
@ -23,22 +121,22 @@ use std::collections::{HashMap, HashSet};
/// ```
pub struct BpeTrainer {
/// The minimum frequency a pair must have to produce a merge operation
pub min_frequency: u32,
min_frequency: u32,
/// The target vocabulary size
pub vocab_size: usize,
vocab_size: usize,
/// Whether to show progress while training
pub show_progress: bool,
show_progress: bool,
/// A list of special tokens that the model should know of
pub special_tokens: Vec<String>,
special_tokens: Vec<String>,
/// Whether to limit the number of initial tokens that can be kept before computing merges
pub limit_alphabet: Option<usize>,
limit_alphabet: Option<usize>,
/// The initial alphabet we want absolutely to include. This allows to cover
/// some characters that are not necessarily in the training set
pub initial_alphabet: HashSet<char>,
initial_alphabet: HashSet<char>,
/// An optional prefix to use on any subword that exist only behind another one
pub continuing_subword_prefix: Option<String>,
continuing_subword_prefix: Option<String>,
/// An optional suffix to caracterize and end-of-word subword
pub end_of_word_suffix: Option<String>,
end_of_word_suffix: Option<String>,
}
impl Default for BpeTrainer {
@ -65,6 +163,10 @@ impl BpeTrainer {
}
}
pub fn builder() -> BpeTrainerBuilder {
BpeTrainerBuilder::new()
}
/// Setup a progress bar if asked to show progress
fn setup_progress(&self) -> Option<ProgressBar> {
if self.show_progress {