mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
tweak builder pattern to find defaults in Builder::Default
(#42)
* alternate builder pattern * BpeBuilder * update BpeTrainerBuilder
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use super::{Cache, Error, Pair, WithFirstLastIterator, Word};
|
||||
use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
|
||||
use crate::tokenizer::{Model, Offsets, Result, Token};
|
||||
use rand::{thread_rng, Rng};
|
||||
use serde_json::Value;
|
||||
@ -10,11 +10,10 @@ use std::{
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
struct Config {
|
||||
vocab: Option<HashMap<String, u32>>,
|
||||
merges: Option<HashMap<Pair, (u32, u32)>>,
|
||||
cache_capacity: Option<usize>,
|
||||
vocab: HashMap<String, u32>,
|
||||
merges: HashMap<Pair, (u32, u32)>,
|
||||
cache_capacity: usize,
|
||||
dropout: Option<f32>,
|
||||
unk_token: Option<String>,
|
||||
continuing_subword_prefix: Option<String>,
|
||||
@ -22,11 +21,26 @@ struct Config {
|
||||
}
|
||||
|
||||
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
|
||||
#[derive(Default)]
|
||||
pub struct BpeBuilder {
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Default for BpeBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config: Config {
|
||||
vocab: HashMap::new(),
|
||||
merges: HashMap::new(),
|
||||
cache_capacity: DEFAULT_CACHE_CAPACITY,
|
||||
dropout: None,
|
||||
unk_token: None,
|
||||
continuing_subword_prefix: None,
|
||||
end_of_word_suffix: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BpeBuilder {
|
||||
/// Constructs a new `BpeBuilder`.
|
||||
pub fn new() -> Self {
|
||||
@ -39,14 +53,14 @@ impl BpeBuilder {
|
||||
vocab: HashMap<String, u32>,
|
||||
merges: HashMap<Pair, (u32, u32)>,
|
||||
) -> Self {
|
||||
self.config.vocab = Some(vocab);
|
||||
self.config.merges = Some(merges);
|
||||
self.config.vocab = vocab;
|
||||
self.config.merges = merges;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the cache's capacity. If the capacity is set to 0, no cache will be used.
|
||||
/// Set the cache's capacity. Set to 0 if you want to disable caching.
|
||||
pub fn cache_capacity(mut self, capacity: usize) -> Self {
|
||||
self.config.cache_capacity = Some(capacity);
|
||||
self.config.cache_capacity = capacity;
|
||||
self
|
||||
}
|
||||
|
||||
@ -76,43 +90,34 @@ impl BpeBuilder {
|
||||
|
||||
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
|
||||
pub fn build(self) -> Result<BPE> {
|
||||
let mut bpe = BPE::default();
|
||||
|
||||
// Validate dropout.
|
||||
if let Some(p) = self.config.dropout {
|
||||
if p < 0.0 || p > 1.0 {
|
||||
if p <= 0.0 || p > 1.0 {
|
||||
return Err(Error::InvalidDropout.into());
|
||||
} else {
|
||||
bpe.dropout = Some(p);
|
||||
}
|
||||
}
|
||||
if let Some(vocab) = self.config.vocab {
|
||||
bpe.vocab_r = vocab
|
||||
.iter()
|
||||
.map(|(key, val)| (*val, key.to_owned()))
|
||||
.collect();
|
||||
bpe.vocab = vocab;
|
||||
}
|
||||
if let Some(merges) = self.config.merges {
|
||||
bpe.merges = merges;
|
||||
}
|
||||
if let Some(capacity) = self.config.cache_capacity {
|
||||
bpe.cache = match capacity {
|
||||
0 => None,
|
||||
_ => Some(Cache::new(capacity)),
|
||||
};
|
||||
}
|
||||
if let Some(unk_token) = self.config.unk_token {
|
||||
bpe.unk_token = Some(unk_token);
|
||||
}
|
||||
if let Some(continuing_subword_prefix) = self.config.continuing_subword_prefix {
|
||||
bpe.continuing_subword_prefix = Some(continuing_subword_prefix);
|
||||
}
|
||||
if let Some(end_of_word_suffix) = self.config.end_of_word_suffix {
|
||||
bpe.end_of_word_suffix = Some(end_of_word_suffix);
|
||||
}
|
||||
|
||||
Ok(bpe)
|
||||
let vocab_r = self
|
||||
.config
|
||||
.vocab
|
||||
.iter()
|
||||
.map(|(key, val)| (*val, key.to_owned()))
|
||||
.collect();
|
||||
let cache = match self.config.cache_capacity {
|
||||
0 => None,
|
||||
capacity => Some(Cache::new(capacity)),
|
||||
};
|
||||
|
||||
Ok(BPE {
|
||||
vocab: self.config.vocab,
|
||||
vocab_r,
|
||||
merges: self.config.merges,
|
||||
cache,
|
||||
dropout: self.config.dropout,
|
||||
unk_token: self.config.unk_token,
|
||||
continuing_subword_prefix: self.config.continuing_subword_prefix,
|
||||
end_of_word_suffix: self.config.end_of_word_suffix,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,16 +144,7 @@ pub struct BPE {
|
||||
|
||||
impl Default for BPE {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab: HashMap::new(),
|
||||
vocab_r: HashMap::new(),
|
||||
merges: HashMap::new(),
|
||||
cache: Some(Cache::default()),
|
||||
dropout: None,
|
||||
unk_token: None,
|
||||
continuing_subword_prefix: None,
|
||||
end_of_word_suffix: None,
|
||||
}
|
||||
Self::builder().build().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,25 +5,40 @@ use crate::tokenizer::{Model, Result, Trainer};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
#[derive(Default)]
|
||||
struct Config {
|
||||
min_frequency: Option<u32>,
|
||||
vocab_size: Option<usize>,
|
||||
show_progress: Option<bool>,
|
||||
special_tokens: Option<Vec<String>>,
|
||||
min_frequency: u32,
|
||||
vocab_size: usize,
|
||||
show_progress: bool,
|
||||
special_tokens: Vec<String>,
|
||||
limit_alphabet: Option<usize>,
|
||||
initial_alphabet: Option<HashSet<char>>,
|
||||
initial_alphabet: HashSet<char>,
|
||||
continuing_subword_prefix: Option<String>,
|
||||
end_of_word_suffix: Option<String>,
|
||||
}
|
||||
|
||||
/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
|
||||
/// configuration.
|
||||
#[derive(Default)]
|
||||
pub struct BpeTrainerBuilder {
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Default for BpeTrainerBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config: Config {
|
||||
min_frequency: 0,
|
||||
vocab_size: 30000,
|
||||
show_progress: true,
|
||||
special_tokens: vec![],
|
||||
limit_alphabet: None,
|
||||
initial_alphabet: HashSet::new(),
|
||||
continuing_subword_prefix: None,
|
||||
end_of_word_suffix: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BpeTrainerBuilder {
|
||||
/// Constructs a new `BpeTrainerBuilder`
|
||||
pub fn new() -> Self {
|
||||
@ -32,25 +47,25 @@ impl BpeTrainerBuilder {
|
||||
|
||||
/// Set the expected minimum frequency
|
||||
pub fn min_frequency(mut self, frequency: u32) -> Self {
|
||||
self.config.min_frequency = Some(frequency);
|
||||
self.config.min_frequency = frequency;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the vocabulary size
|
||||
pub fn vocab_size(mut self, size: usize) -> Self {
|
||||
self.config.vocab_size = Some(size);
|
||||
self.config.vocab_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to show progress
|
||||
pub fn show_progress(mut self, show: bool) -> Self {
|
||||
self.config.show_progress = Some(show);
|
||||
self.config.show_progress = show;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the special tokens
|
||||
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
|
||||
self.config.special_tokens = Some(tokens);
|
||||
self.config.special_tokens = tokens;
|
||||
self
|
||||
}
|
||||
|
||||
@ -62,7 +77,7 @@ impl BpeTrainerBuilder {
|
||||
|
||||
/// Set the initial alphabet
|
||||
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
|
||||
self.config.initial_alphabet = Some(alphabet);
|
||||
self.config.initial_alphabet = alphabet;
|
||||
self
|
||||
}
|
||||
|
||||
@ -79,29 +94,17 @@ impl BpeTrainerBuilder {
|
||||
}
|
||||
|
||||
/// Constructs the final BpeTrainer
|
||||
pub fn build(self) -> Result<BpeTrainer> {
|
||||
let mut trainer = BpeTrainer::default();
|
||||
|
||||
if let Some(freq) = self.config.min_frequency {
|
||||
trainer.min_frequency = freq;
|
||||
pub fn build(self) -> BpeTrainer {
|
||||
BpeTrainer {
|
||||
min_frequency: self.config.min_frequency,
|
||||
vocab_size: self.config.vocab_size,
|
||||
show_progress: self.config.show_progress,
|
||||
special_tokens: self.config.special_tokens,
|
||||
limit_alphabet: self.config.limit_alphabet,
|
||||
initial_alphabet: self.config.initial_alphabet,
|
||||
continuing_subword_prefix: self.config.continuing_subword_prefix,
|
||||
end_of_word_suffix: self.config.end_of_word_suffix,
|
||||
}
|
||||
if let Some(vocab_size) = self.config.vocab_size {
|
||||
trainer.vocab_size = vocab_size;
|
||||
}
|
||||
if let Some(show) = self.config.show_progress {
|
||||
trainer.show_progress = show;
|
||||
}
|
||||
if let Some(special_tokens) = self.config.special_tokens {
|
||||
trainer.special_tokens = special_tokens;
|
||||
}
|
||||
if let Some(alphabet) = self.config.initial_alphabet {
|
||||
trainer.initial_alphabet = alphabet;
|
||||
}
|
||||
trainer.limit_alphabet = self.config.limit_alphabet;
|
||||
trainer.continuing_subword_prefix = self.config.continuing_subword_prefix;
|
||||
trainer.end_of_word_suffix = self.config.end_of_word_suffix;
|
||||
|
||||
Ok(trainer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -143,16 +146,7 @@ pub struct BpeTrainer {
|
||||
|
||||
impl Default for BpeTrainer {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_frequency: 0,
|
||||
vocab_size: 30000,
|
||||
show_progress: true,
|
||||
special_tokens: vec![],
|
||||
limit_alphabet: None,
|
||||
initial_alphabet: HashSet::new(),
|
||||
continuing_subword_prefix: None,
|
||||
end_of_word_suffix: None,
|
||||
}
|
||||
Self::builder().build()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,20 +32,31 @@ impl fmt::Display for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Config {
|
||||
vocab: Option<HashMap<String, u32>>,
|
||||
unk_token: Option<String>,
|
||||
continuing_subword_prefix: Option<String>,
|
||||
max_input_chars_per_word: Option<usize>,
|
||||
vocab: HashMap<String, u32>,
|
||||
unk_token: String,
|
||||
continuing_subword_prefix: String,
|
||||
max_input_chars_per_word: usize,
|
||||
}
|
||||
|
||||
/// A `WordPieceBuilder` can be used to create a `WordPiece` model with a custom configuration.
|
||||
#[derive(Default)]
|
||||
pub struct WordPieceBuilder {
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl Default for WordPieceBuilder {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config: Config {
|
||||
vocab: HashMap::new(),
|
||||
unk_token: String::from("[UNK]"),
|
||||
continuing_subword_prefix: String::from("##"),
|
||||
max_input_chars_per_word: 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WordPieceBuilder {
|
||||
/// Construct a new `WordPieceBuilder`.
|
||||
pub fn new() -> Self {
|
||||
@ -54,50 +65,43 @@ impl WordPieceBuilder {
|
||||
|
||||
/// Set the vocab (token -> ID) mapping.
|
||||
pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
|
||||
self.config.vocab = Some(vocab);
|
||||
self.config.vocab = vocab;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the `UNK` token for the vocab.
|
||||
pub fn unk_token(mut self, unk_token: String) -> Self {
|
||||
self.config.unk_token = Some(unk_token);
|
||||
self.config.unk_token = unk_token;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the prefix for continuing subwords.
|
||||
pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: String) -> Self {
|
||||
self.config.continuing_subword_prefix = Some(continuing_subword_prefix);
|
||||
self.config.continuing_subword_prefix = continuing_subword_prefix;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum number of input characters per word.
|
||||
pub fn max_input_chars_per_word(mut self, max_input_chars_per_word: usize) -> Self {
|
||||
self.config.max_input_chars_per_word = Some(max_input_chars_per_word);
|
||||
self.config.max_input_chars_per_word = max_input_chars_per_word;
|
||||
self
|
||||
}
|
||||
|
||||
/// Constructs a `WordPiece` model that uses the `WordPieceBuilder`'s configuration.
|
||||
pub fn build(self) -> WordPiece {
|
||||
let mut wp = WordPiece::default();
|
||||
|
||||
if let Some(vocab) = self.config.vocab {
|
||||
wp.vocab_r = vocab
|
||||
.iter()
|
||||
.map(|(key, val)| (*val, key.to_owned()))
|
||||
.collect();
|
||||
wp.vocab = vocab;
|
||||
let vocab_r = self
|
||||
.config
|
||||
.vocab
|
||||
.iter()
|
||||
.map(|(key, val)| (*val, key.to_owned()))
|
||||
.collect();
|
||||
WordPiece {
|
||||
vocab: self.config.vocab,
|
||||
vocab_r,
|
||||
unk_token: self.config.unk_token,
|
||||
continuing_subword_prefix: self.config.continuing_subword_prefix,
|
||||
max_input_chars_per_word: self.config.max_input_chars_per_word,
|
||||
}
|
||||
if let Some(unk_token) = self.config.unk_token {
|
||||
wp.unk_token = unk_token;
|
||||
}
|
||||
if let Some(continuing_subword_prefix) = self.config.continuing_subword_prefix {
|
||||
wp.continuing_subword_prefix = continuing_subword_prefix;
|
||||
}
|
||||
if let Some(max_input_chars_per_word) = self.config.max_input_chars_per_word {
|
||||
wp.max_input_chars_per_word = max_input_chars_per_word;
|
||||
}
|
||||
|
||||
wp
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,9 +72,9 @@ impl WordPieceTrainerBuilder {
|
||||
}
|
||||
|
||||
/// Constructs the final WordPieceTrainer
|
||||
pub fn build(self) -> Result<WordPieceTrainer> {
|
||||
let bpe_trainer = self.bpe_trainer_builder.build()?;
|
||||
Ok(WordPieceTrainer { bpe_trainer })
|
||||
pub fn build(self) -> WordPieceTrainer {
|
||||
let bpe_trainer = self.bpe_trainer_builder.build();
|
||||
WordPieceTrainer { bpe_trainer }
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user