Add WordPieceTrainer
tokenizers/src/models/bpe/trainer.rs
@@ -319,11 +319,8 @@ impl BpeTrainer {
         (words, counts)
     }
-}
 
-impl Trainer for BpeTrainer {
-    /// Train a BPE model
-    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
+    pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<BPE> {
         let mut word_to_id: HashMap<String, u32> = HashMap::new();
         let mut id_to_word: Vec<String> = vec![];
 
@@ -463,14 +460,22 @@ impl Trainer for BpeTrainer {
         }
         self.finalize_progress(&progress, merges.len());
 
-        Ok(Box::new(BPE::new(
+        Ok(BPE::new(
             word_to_id,
             merges
                 .into_iter()
                 .enumerate()
                 .map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
                 .collect(),
-        )))
+        ))
     }
+}
+
+impl Trainer for BpeTrainer {
+    /// Train a BPE model
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
+        let bpe = self.train(word_counts)?;
+        Ok(Box::new(bpe))
+    }
 
     /// Process a bunch of tokens, counting them
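Note on the delegation in the hunk above: `Trainer::train` calls `self.train(word_counts)`, and because Rust prefers inherent methods over trait methods during method resolution, that call dispatches to the new `pub fn train(&self, ...) -> Result<BPE>` rather than recursing into the trait method. A minimal, self-contained sketch of that pattern, using hypothetical names (`Train`, `MyTrainer`) rather than anything from this crate:

use std::fmt::Debug;

// Hypothetical trait and type, only to illustrate method resolution.
trait Train {
    fn train(&self) -> Box<dyn Debug>;
}

struct MyTrainer;

impl MyTrainer {
    // Inherent method returning a concrete type.
    fn train(&self) -> u32 {
        42
    }
}

impl Train for MyTrainer {
    fn train(&self) -> Box<dyn Debug> {
        // Resolves to the inherent `MyTrainer::train`, not to this trait
        // method, so there is no infinite recursion.
        let concrete = self.train();
        Box::new(concrete)
    }
}

fn main() {
    let t = MyTrainer;
    println!("inherent: {}", t.train());          // 42 (inherent method wins)
    println!("trait:    {:?}", Train::train(&t)); // 42 boxed as `dyn Debug`
}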
tokenizers/src/models/wordpiece/mod.rs
@@ -9,6 +9,9 @@ use std::{
     path::{Path, PathBuf},
 };
 
+mod trainer;
+pub use trainer::*;
+
 #[derive(Debug)]
 pub enum Error {
     MissingUnkToken,
tokenizers/src/models/wordpiece/trainer.rs (new file, 97 lines added)
@@ -0,0 +1,97 @@
+use super::WordPiece;
+use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder};
+use crate::tokenizer::{Model, Result, Trainer};
+use std::collections::{HashMap, HashSet};
+
+#[derive(Default)]
+pub struct WordPieceTrainerBuilder {
+    bpe_trainer_builder: BpeTrainerBuilder,
+}
+
+impl WordPieceTrainerBuilder {
+    /// Constructs a new `WordPieceTrainerBuilder`
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the expected minimum frequency
+    pub fn min_frequency(mut self, frequency: u32) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
+        self
+    }
+
+    /// Set the vocabulary size
+    pub fn vocab_size(mut self, size: usize) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.vocab_size(size);
+        self
+    }
+
+    /// Set whether to show progress
+    pub fn show_progress(mut self, show: bool) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.show_progress(show);
+        self
+    }
+
+    /// Set the special tokens
+    pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.special_tokens(tokens);
+        self
+    }
+
+    /// Set whether to limit the alphabet
+    pub fn limit_alphabet(mut self, limit: usize) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.limit_alphabet(limit);
+        self
+    }
+
+    /// Set the initial alphabet
+    pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet);
+        self
+    }
+
+    /// Set the continuing_subword_prefix
+    pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix);
+        self
+    }
+
+    /// Set the end_of_word_suffix
+    pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
+        self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix);
+        self
+    }
+
+    /// Constructs the final BpeTrainer
+    pub fn build(self) -> Result<WordPieceTrainer> {
+        let bpe_trainer = self.bpe_trainer_builder.build()?;
+        Ok(WordPieceTrainer { bpe_trainer })
+    }
+}
+
+#[derive(Default)]
+pub struct WordPieceTrainer {
+    bpe_trainer: BpeTrainer,
+}
+
+impl WordPieceTrainer {
+    pub fn builder() -> WordPieceTrainerBuilder {
+        WordPieceTrainerBuilder::default()
+    }
+
+    pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<WordPiece> {
+        let bpe = self.bpe_trainer.train(word_counts)?;
+        Ok(WordPiece::from_bpe(&bpe))
+    }
+}
+
+impl Trainer for WordPieceTrainer {
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
+        let wp = self.train(word_counts)?;
+        Ok(Box::new(wp))
+    }
+
+    fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
+        self.bpe_trainer.process_tokens(&mut words, tokens)
+    }
+}
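For orientation, a rough usage sketch of the API introduced by this commit. It assumes the `wordpiece` module (with its new `pub use trainer::*;`) is publicly reachable as `tokenizers::models::wordpiece` and that word counts have been collected elsewhere; neither detail is shown in this diff.

use std::collections::HashMap;

// Assumed public paths; only the builder methods and `train` come from this commit.
use tokenizers::models::wordpiece::{WordPiece, WordPieceTrainer};
use tokenizers::tokenizer::Result;

fn train_wordpiece(word_counts: HashMap<String, u32>) -> Result<WordPiece> {
    // Each setter forwards to the wrapped BpeTrainerBuilder.
    let trainer = WordPieceTrainer::builder()
        .min_frequency(2)
        .vocab_size(30_000)
        .show_progress(true)
        .special_tokens(vec!["[UNK]".into(), "[CLS]".into(), "[SEP]".into()])
        .continuing_subword_prefix("##".into())
        .build()?;

    // Trains the underlying BPE model, then converts it via WordPiece::from_bpe.
    trainer.train(word_counts)
}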