Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Add BpeTrainerBuilder
@@ -5,6 +5,104 @@ use crate::tokenizer::{Model, Result, Trainer};
 use indicatif::{ProgressBar, ProgressStyle};
 use std::collections::{HashMap, HashSet};
 
+#[derive(Default)]
+pub struct Config {
+    min_frequency: Option<u32>,
+    vocab_size: Option<usize>,
+    show_progress: Option<bool>,
+    special_tokens: Option<Vec<String>>,
+    limit_alphabet: Option<usize>,
+    initial_alphabet: Option<HashSet<char>>,
+    continuing_subword_prefix: Option<String>,
+    end_of_word_suffix: Option<String>,
+}
+
+#[derive(Default)]
+pub struct BpeTrainerBuilder {
+    config: Config,
+}
+
+impl BpeTrainerBuilder {
+    /// Constructs a new `BpeTrainerBuilder`
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the expected minimum frequency
+    pub fn min_frequency(mut self, frequency: u32) -> Self {
+        self.config.min_frequency = Some(frequency);
+        self
+    }
+
+    /// Set the vocabulary size
+    pub fn vocab_size(mut self, size: usize) -> Self {
+        self.config.vocab_size = Some(size);
+        self
+    }
+
+    /// Set whether to show progress
+    pub fn show_progress(mut self, show: bool) -> Self {
+        self.config.show_progress = Some(show);
+        self
+    }
+
+    /// Set the special tokens
+    pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
+        self.config.special_tokens = Some(tokens);
+        self
+    }
+
+    /// Set whether to limit the alphabet
+    pub fn limit_alphabet(mut self, limit: usize) -> Self {
+        self.config.limit_alphabet = Some(limit);
+        self
+    }
+
+    /// Set the initial alphabet
+    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
+        self.config.initial_alphabet = Some(alphabet);
+        self
+    }
+
+    /// Set the continuing_subword_prefix
+    pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
+        self.config.continuing_subword_prefix = Some(prefix);
+        self
+    }
+
+    /// Set the end_of_word_suffix
+    pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
+        self.config.end_of_word_suffix = Some(suffix);
+        self
+    }
+
+    /// Constructs the final BpeTrainer
+    pub fn build(self) -> Result<BpeTrainer> {
+        let mut trainer = BpeTrainer::default();
+
+        if let Some(freq) = self.config.min_frequency {
+            trainer.min_frequency = freq;
+        }
+        if let Some(vocab_size) = self.config.vocab_size {
+            trainer.vocab_size = vocab_size;
+        }
+        if let Some(show) = self.config.show_progress {
+            trainer.show_progress = show;
+        }
+        if let Some(special_tokens) = self.config.special_tokens {
+            trainer.special_tokens = special_tokens;
+        }
+        if let Some(alphabet) = self.config.initial_alphabet {
+            trainer.initial_alphabet = alphabet;
+        }
+        trainer.limit_alphabet = self.config.limit_alphabet;
+        trainer.continuing_subword_prefix = self.config.continuing_subword_prefix;
+        trainer.end_of_word_suffix = self.config.end_of_word_suffix;
+
+        Ok(trainer)
+    }
+}
+
 /// In charge of training a BPE model from a mapping of words to word counts.
 ///
 /// # Examples
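For context, a minimal usage sketch of the builder introduced in the hunk above. The `use` path, the chosen values, and the `main` wrapper are assumptions for illustration; only the method names and signatures come from the diff:

    use std::collections::HashSet;

    // Module path is an assumption (the diff only shows `crate::tokenizer::...`);
    // adjust it to wherever BpeTrainerBuilder is actually exported.
    use tokenizers::models::bpe::BpeTrainerBuilder;

    fn main() {
        let initial: HashSet<char> = ['e', 'f', 'g'].iter().copied().collect();

        // Every setter is optional; build() starts from BpeTrainer::default()
        // and only overrides the fields that were explicitly set.
        let trainer = BpeTrainerBuilder::new()
            .min_frequency(2)
            .vocab_size(30_000)
            .show_progress(true)
            .special_tokens(vec!["<unk>".to_string()])
            .limit_alphabet(1_000)
            .initial_alphabet(initial)
            .continuing_subword_prefix("##".to_string())
            .end_of_word_suffix("</w>".to_string())
            .build()
            .expect("building a BpeTrainer should not fail"); // build() returns Result<BpeTrainer>
        let _ = trainer;
    }

Since build() returns the crate's Result type, the error could also be propagated with ? instead of expect().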
@@ -23,22 +121,22 @@ use std::collections::{HashMap, HashSet};
 /// ```
 pub struct BpeTrainer {
     /// The minimum frequency a pair must have to produce a merge operation
-    pub min_frequency: u32,
+    min_frequency: u32,
     /// The target vocabulary size
-    pub vocab_size: usize,
+    vocab_size: usize,
     /// Whether to show progress while training
-    pub show_progress: bool,
+    show_progress: bool,
     /// A list of special tokens that the model should know of
-    pub special_tokens: Vec<String>,
+    special_tokens: Vec<String>,
     /// Whether to limit the number of initial tokens that can be kept before computing merges
-    pub limit_alphabet: Option<usize>,
+    limit_alphabet: Option<usize>,
     /// The initial alphabet we want absolutely to include. This allows us to cover
     /// some characters that are not necessarily in the training set
-    pub initial_alphabet: HashSet<char>,
+    initial_alphabet: HashSet<char>,
     /// An optional prefix to use on any subword that exists only behind another one
-    pub continuing_subword_prefix: Option<String>,
+    continuing_subword_prefix: Option<String>,
     /// An optional suffix to characterize an end-of-word subword
-    pub end_of_word_suffix: Option<String>,
+    end_of_word_suffix: Option<String>,
 }
 
 impl Default for BpeTrainer {
@@ -65,6 +163,10 @@ impl BpeTrainer {
         }
     }
 
+    pub fn builder() -> BpeTrainerBuilder {
+        BpeTrainerBuilder::new()
+    }
+
     /// Setup a progress bar if asked to show progress
     fn setup_progress(&self) -> Option<ProgressBar> {
        if self.show_progress {
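The builder() associated function added in this last hunk is a convenience entry point equivalent to BpeTrainerBuilder::new(). A hypothetical fragment under the same path assumption as the earlier sketch:

    // Shorthand: start from the trainer type instead of naming the builder directly.
    let trainer = tokenizers::models::bpe::BpeTrainer::builder()
        .vocab_size(50_000)
        .show_progress(false)
        .build()
        .expect("valid trainer configuration");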