tweak builder pattern to find defaults in Builder::Default (#42)

* alternate builder pattern

* BpeBuilder

* update BpeTrainerBuilder
This commit is contained in:
Evan Pete Walsh
2020-01-08 12:10:34 -08:00
committed by GitHub
parent c7d2800131
commit d2d5b1eae7
4 changed files with 121 additions and 127 deletions

View File

@ -1,4 +1,4 @@
use super::{Cache, Error, Pair, WithFirstLastIterator, Word};
use super::{Cache, Error, Pair, WithFirstLastIterator, Word, DEFAULT_CACHE_CAPACITY};
use crate::tokenizer::{Model, Offsets, Result, Token};
use rand::{thread_rng, Rng};
use serde_json::Value;
@ -10,11 +10,10 @@ use std::{
path::{Path, PathBuf},
};
#[derive(Default)]
struct Config {
vocab: Option<HashMap<String, u32>>,
merges: Option<HashMap<Pair, (u32, u32)>>,
cache_capacity: Option<usize>,
vocab: HashMap<String, u32>,
merges: HashMap<Pair, (u32, u32)>,
cache_capacity: usize,
dropout: Option<f32>,
unk_token: Option<String>,
continuing_subword_prefix: Option<String>,
@ -22,11 +21,26 @@ struct Config {
}
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
#[derive(Default)]
pub struct BpeBuilder {
config: Config,
}
impl Default for BpeBuilder {
fn default() -> Self {
Self {
config: Config {
vocab: HashMap::new(),
merges: HashMap::new(),
cache_capacity: DEFAULT_CACHE_CAPACITY,
dropout: None,
unk_token: None,
continuing_subword_prefix: None,
end_of_word_suffix: None,
},
}
}
}
impl BpeBuilder {
/// Constructs a new `BpeBuilder`.
pub fn new() -> Self {
@ -39,14 +53,14 @@ impl BpeBuilder {
vocab: HashMap<String, u32>,
merges: HashMap<Pair, (u32, u32)>,
) -> Self {
self.config.vocab = Some(vocab);
self.config.merges = Some(merges);
self.config.vocab = vocab;
self.config.merges = merges;
self
}
/// Set the cache's capacity. If the capacity is set to 0, no cache will be used.
/// Set the cache's capacity. Set to 0 if you want to disable caching.
pub fn cache_capacity(mut self, capacity: usize) -> Self {
self.config.cache_capacity = Some(capacity);
self.config.cache_capacity = capacity;
self
}
@ -76,43 +90,34 @@ impl BpeBuilder {
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(self) -> Result<BPE> {
let mut bpe = BPE::default();
// Validate dropout.
if let Some(p) = self.config.dropout {
if p < 0.0 || p > 1.0 {
if p <= 0.0 || p > 1.0 {
return Err(Error::InvalidDropout.into());
} else {
bpe.dropout = Some(p);
}
}
if let Some(vocab) = self.config.vocab {
bpe.vocab_r = vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
bpe.vocab = vocab;
}
if let Some(merges) = self.config.merges {
bpe.merges = merges;
}
if let Some(capacity) = self.config.cache_capacity {
bpe.cache = match capacity {
0 => None,
_ => Some(Cache::new(capacity)),
};
}
if let Some(unk_token) = self.config.unk_token {
bpe.unk_token = Some(unk_token);
}
if let Some(continuing_subword_prefix) = self.config.continuing_subword_prefix {
bpe.continuing_subword_prefix = Some(continuing_subword_prefix);
}
if let Some(end_of_word_suffix) = self.config.end_of_word_suffix {
bpe.end_of_word_suffix = Some(end_of_word_suffix);
}
Ok(bpe)
let vocab_r = self
.config
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
let cache = match self.config.cache_capacity {
0 => None,
capacity => Some(Cache::new(capacity)),
};
Ok(BPE {
vocab: self.config.vocab,
vocab_r,
merges: self.config.merges,
cache,
dropout: self.config.dropout,
unk_token: self.config.unk_token,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
})
}
}
@ -139,16 +144,7 @@ pub struct BPE {
impl Default for BPE {
fn default() -> Self {
Self {
vocab: HashMap::new(),
vocab_r: HashMap::new(),
merges: HashMap::new(),
cache: Some(Cache::default()),
dropout: None,
unk_token: None,
continuing_subword_prefix: None,
end_of_word_suffix: None,
}
Self::builder().build().unwrap()
}
}

View File

@ -5,25 +5,40 @@ use crate::tokenizer::{Model, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle};
use std::collections::{HashMap, HashSet};
#[derive(Default)]
struct Config {
min_frequency: Option<u32>,
vocab_size: Option<usize>,
show_progress: Option<bool>,
special_tokens: Option<Vec<String>>,
min_frequency: u32,
vocab_size: usize,
show_progress: bool,
special_tokens: Vec<String>,
limit_alphabet: Option<usize>,
initial_alphabet: Option<HashSet<char>>,
initial_alphabet: HashSet<char>,
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
}
/// A `BpeTrainerBuilder` can be used to create a `BpeTrainer` with a custom
/// configuration.
#[derive(Default)]
pub struct BpeTrainerBuilder {
config: Config,
}
impl Default for BpeTrainerBuilder {
fn default() -> Self {
Self {
config: Config {
min_frequency: 0,
vocab_size: 30000,
show_progress: true,
special_tokens: vec![],
limit_alphabet: None,
initial_alphabet: HashSet::new(),
continuing_subword_prefix: None,
end_of_word_suffix: None,
},
}
}
}
impl BpeTrainerBuilder {
/// Constructs a new `BpeTrainerBuilder`
pub fn new() -> Self {
@ -32,25 +47,25 @@ impl BpeTrainerBuilder {
/// Set the expected minimum frequency
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.config.min_frequency = Some(frequency);
self.config.min_frequency = frequency;
self
}
/// Set the vocabulary size
pub fn vocab_size(mut self, size: usize) -> Self {
self.config.vocab_size = Some(size);
self.config.vocab_size = size;
self
}
/// Set whether to show progress
pub fn show_progress(mut self, show: bool) -> Self {
self.config.show_progress = Some(show);
self.config.show_progress = show;
self
}
/// Set the special tokens
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
self.config.special_tokens = Some(tokens);
self.config.special_tokens = tokens;
self
}
@ -62,7 +77,7 @@ impl BpeTrainerBuilder {
/// Set the initial alphabet
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.config.initial_alphabet = Some(alphabet);
self.config.initial_alphabet = alphabet;
self
}
@ -79,29 +94,17 @@ impl BpeTrainerBuilder {
}
/// Constructs the final BpeTrainer
pub fn build(self) -> Result<BpeTrainer> {
let mut trainer = BpeTrainer::default();
if let Some(freq) = self.config.min_frequency {
trainer.min_frequency = freq;
pub fn build(self) -> BpeTrainer {
BpeTrainer {
min_frequency: self.config.min_frequency,
vocab_size: self.config.vocab_size,
show_progress: self.config.show_progress,
special_tokens: self.config.special_tokens,
limit_alphabet: self.config.limit_alphabet,
initial_alphabet: self.config.initial_alphabet,
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
}
if let Some(vocab_size) = self.config.vocab_size {
trainer.vocab_size = vocab_size;
}
if let Some(show) = self.config.show_progress {
trainer.show_progress = show;
}
if let Some(special_tokens) = self.config.special_tokens {
trainer.special_tokens = special_tokens;
}
if let Some(alphabet) = self.config.initial_alphabet {
trainer.initial_alphabet = alphabet;
}
trainer.limit_alphabet = self.config.limit_alphabet;
trainer.continuing_subword_prefix = self.config.continuing_subword_prefix;
trainer.end_of_word_suffix = self.config.end_of_word_suffix;
Ok(trainer)
}
}
@ -143,16 +146,7 @@ pub struct BpeTrainer {
impl Default for BpeTrainer {
fn default() -> Self {
Self {
min_frequency: 0,
vocab_size: 30000,
show_progress: true,
special_tokens: vec![],
limit_alphabet: None,
initial_alphabet: HashSet::new(),
continuing_subword_prefix: None,
end_of_word_suffix: None,
}
Self::builder().build()
}
}

View File

@ -32,20 +32,31 @@ impl fmt::Display for Error {
}
}
#[derive(Default)]
struct Config {
vocab: Option<HashMap<String, u32>>,
unk_token: Option<String>,
continuing_subword_prefix: Option<String>,
max_input_chars_per_word: Option<usize>,
vocab: HashMap<String, u32>,
unk_token: String,
continuing_subword_prefix: String,
max_input_chars_per_word: usize,
}
/// A `WordPieceBuilder` can be used to create a `WordPiece` model with a custom configuration.
#[derive(Default)]
pub struct WordPieceBuilder {
config: Config,
}
impl Default for WordPieceBuilder {
fn default() -> Self {
Self {
config: Config {
vocab: HashMap::new(),
unk_token: String::from("[UNK]"),
continuing_subword_prefix: String::from("##"),
max_input_chars_per_word: 100,
},
}
}
}
impl WordPieceBuilder {
/// Construct a new `WordPieceBuilder`.
pub fn new() -> Self {
@ -54,50 +65,43 @@ impl WordPieceBuilder {
/// Set the vocab (token -> ID) mapping.
pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
self.config.vocab = Some(vocab);
self.config.vocab = vocab;
self
}
/// The the `UNK` token for the vocab.
pub fn unk_token(mut self, unk_token: String) -> Self {
self.config.unk_token = Some(unk_token);
self.config.unk_token = unk_token;
self
}
/// Set the prefix for continuing subwords.
pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: String) -> Self {
self.config.continuing_subword_prefix = Some(continuing_subword_prefix);
self.config.continuing_subword_prefix = continuing_subword_prefix;
self
}
/// Set the maximum number of input characters per word.
pub fn max_input_chars_per_word(mut self, max_input_chars_per_word: usize) -> Self {
self.config.max_input_chars_per_word = Some(max_input_chars_per_word);
self.config.max_input_chars_per_word = max_input_chars_per_word;
self
}
/// Contructs a `WordPiece` model that uses the `WordPieceBuilder`'s configuration.
pub fn build(self) -> WordPiece {
let mut wp = WordPiece::default();
if let Some(vocab) = self.config.vocab {
wp.vocab_r = vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
wp.vocab = vocab;
let vocab_r = self
.config
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
WordPiece {
vocab: self.config.vocab,
vocab_r,
unk_token: self.config.unk_token,
continuing_subword_prefix: self.config.continuing_subword_prefix,
max_input_chars_per_word: self.config.max_input_chars_per_word,
}
if let Some(unk_token) = self.config.unk_token {
wp.unk_token = unk_token;
}
if let Some(continuing_subword_prefix) = self.config.continuing_subword_prefix {
wp.continuing_subword_prefix = continuing_subword_prefix;
}
if let Some(max_input_chars_per_word) = self.config.max_input_chars_per_word {
wp.max_input_chars_per_word = max_input_chars_per_word;
}
wp
}
}

View File

@ -72,9 +72,9 @@ impl WordPieceTrainerBuilder {
}
/// Constructs the final BpeTrainer
pub fn build(self) -> Result<WordPieceTrainer> {
let bpe_trainer = self.bpe_trainer_builder.build()?;
Ok(WordPieceTrainer { bpe_trainer })
pub fn build(self) -> WordPieceTrainer {
let bpe_trainer = self.bpe_trainer_builder.build();
WordPieceTrainer { bpe_trainer }
}
}