Add WordPieceTrainer

Anthony MOI
2020-01-03 16:27:36 -05:00
parent 1bfe9fd0a7
commit 1dda76659f
3 changed files with 111 additions and 6 deletions


@@ -319,11 +319,8 @@ impl BpeTrainer {
(words, counts)
}
}
impl Trainer for BpeTrainer {
/// Train a BPE model
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<BPE> {
let mut word_to_id: HashMap<String, u32> = HashMap::new();
let mut id_to_word: Vec<String> = vec![];
@@ -463,14 +460,22 @@ impl Trainer for BpeTrainer {
}
self.finalize_progress(&progress, merges.len());
Ok(Box::new(BPE::new(
Ok(BPE::new(
word_to_id,
merges
.into_iter()
.enumerate()
.map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
.collect(),
)))
))
}
}
impl Trainer for BpeTrainer {
/// Train a BPE model
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
let bpe = self.train(word_counts)?;
Ok(Box::new(bpe))
}
/// Process a bunch of tokens, counting them
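
With this split, callers can obtain a concrete BPE from the inherent train method instead of the boxed Trainer result. A minimal sketch, assuming the crate exposes these items under tokenizers::models::bpe and tokenizers::tokenizer (paths and word counts here are illustrative, not taken from the commit):

use std::collections::HashMap;
use tokenizers::models::bpe::{BpeTrainer, BPE};
use tokenizers::tokenizer::Result;

fn train_bpe() -> Result<BPE> {
    // Hypothetical counts; in practice they come from Trainer::process_tokens.
    let mut word_counts: HashMap<String, u32> = HashMap::new();
    word_counts.insert("hello".into(), 5);
    word_counts.insert("world".into(), 3);

    let trainer = BpeTrainer::default();
    // The inherent train now returns a concrete BPE instead of Box<dyn Model + Sync>.
    trainer.train(word_counts)
}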


@@ -9,6 +9,9 @@ use std::{
path::{Path, PathBuf},
};
mod trainer;
pub use trainer::*;
#[derive(Debug)]
pub enum Error {
MissingUnkToken,


@@ -0,0 +1,97 @@
use super::WordPiece;
use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder};
use crate::tokenizer::{Model, Result, Trainer};
use std::collections::{HashMap, HashSet};
#[derive(Default)]
pub struct WordPieceTrainerBuilder {
bpe_trainer_builder: BpeTrainerBuilder,
}
impl WordPieceTrainerBuilder {
/// Constructs a new `WordPieceTrainerBuilder`
pub fn new() -> Self {
Self::default()
}
/// Set the expected minimum frequency
pub fn min_frequency(mut self, frequency: u32) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
self
}
/// Set the vocabulary size
pub fn vocab_size(mut self, size: usize) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.vocab_size(size);
self
}
/// Set whether to show progress
pub fn show_progress(mut self, show: bool) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.show_progress(show);
self
}
/// Set the special tokens
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.special_tokens(tokens);
self
}
/// Set whether to limit the alphabet
pub fn limit_alphabet(mut self, limit: usize) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.limit_alphabet(limit);
self
}
/// Set the initial alphabet
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet);
self
}
/// Set the continuing_subword_prefix
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix);
self
}
/// Set the end_of_word_suffix
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix);
self
}
/// Constructs the final WordPieceTrainer
pub fn build(self) -> Result<WordPieceTrainer> {
let bpe_trainer = self.bpe_trainer_builder.build()?;
Ok(WordPieceTrainer { bpe_trainer })
}
}
#[derive(Default)]
pub struct WordPieceTrainer {
bpe_trainer: BpeTrainer,
}
impl WordPieceTrainer {
pub fn builder() -> WordPieceTrainerBuilder {
WordPieceTrainerBuilder::default()
}
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<WordPiece> {
let bpe = self.bpe_trainer.train(word_counts)?;
Ok(WordPiece::from_bpe(&bpe))
}
}
impl Trainer for WordPieceTrainer {
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
let wp = self.train(word_counts)?;
Ok(Box::new(wp))
}
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
self.bpe_trainer.process_tokens(&mut words, tokens)
}
}
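
A minimal usage sketch of the new trainer, under the same assumed crate paths as above; the builder settings and counts are illustrative only:

use std::collections::HashMap;
use tokenizers::models::wordpiece::{WordPiece, WordPieceTrainer};
use tokenizers::tokenizer::Result;

fn train_wordpiece() -> Result<WordPiece> {
    // Hypothetical counts; in practice Trainer::process_tokens fills these.
    let mut word_counts: HashMap<String, u32> = HashMap::new();
    word_counts.insert("hello".into(), 5);
    word_counts.insert("world".into(), 3);

    // Every builder option is forwarded to the wrapped BpeTrainerBuilder.
    let trainer = WordPieceTrainer::builder()
        .min_frequency(2)
        .vocab_size(1_000)
        .continuing_subword_prefix("##".into())
        .build()?;

    // The inherent train returns a concrete WordPiece built from the trained BPE;
    // the Trainer impl wraps the same call in a Box<dyn Model + Sync>.
    trainer.train(word_counts)
}

Keeping train as an inherent method on both trainers lets callers pick the concrete model type, while the Trainer impls only delegate and box the result.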