mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Add WordPieceTrainer
This commit is contained in:
@ -319,11 +319,8 @@ impl BpeTrainer {
|
|||||||
|
|
||||||
(words, counts)
|
(words, counts)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl Trainer for BpeTrainer {
|
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<BPE> {
|
||||||
/// Train a BPE model
|
|
||||||
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
|
|
||||||
let mut word_to_id: HashMap<String, u32> = HashMap::new();
|
let mut word_to_id: HashMap<String, u32> = HashMap::new();
|
||||||
let mut id_to_word: Vec<String> = vec![];
|
let mut id_to_word: Vec<String> = vec![];
|
||||||
|
|
||||||
@ -463,14 +460,22 @@ impl Trainer for BpeTrainer {
|
|||||||
}
|
}
|
||||||
self.finalize_progress(&progress, merges.len());
|
self.finalize_progress(&progress, merges.len());
|
||||||
|
|
||||||
Ok(Box::new(BPE::new(
|
Ok(BPE::new(
|
||||||
word_to_id,
|
word_to_id,
|
||||||
merges
|
merges
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
|
.map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
|
||||||
.collect(),
|
.collect(),
|
||||||
)))
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Trainer for BpeTrainer {
|
||||||
|
/// Train a BPE model
|
||||||
|
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
|
||||||
|
let bpe = self.train(word_counts)?;
|
||||||
|
Ok(Box::new(bpe))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process a bunch of tokens, counting them
|
/// Process a bunch of tokens, counting them
|
||||||
|
@ -9,6 +9,9 @@ use std::{
|
|||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
mod trainer;
|
||||||
|
pub use trainer::*;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
MissingUnkToken,
|
MissingUnkToken,
|
||||||
|
97
tokenizers/src/models/wordpiece/trainer.rs
Normal file
97
tokenizers/src/models/wordpiece/trainer.rs
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
use super::WordPiece;
|
||||||
|
use crate::models::bpe::{BpeTrainer, BpeTrainerBuilder};
|
||||||
|
use crate::tokenizer::{Model, Result, Trainer};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
|
/// Builder for `WordPieceTrainer`.
///
/// It holds nothing but a `BpeTrainerBuilder`: every option set through this
/// builder is handed to that inner builder, and `build` wraps the resulting
/// `BpeTrainer` in a `WordPieceTrainer`.
#[derive(Default)]
|
||||||
|
pub struct WordPieceTrainerBuilder {
|
||||||
|
// The wrapped BPE builder that stores all of the configuration.
bpe_trainer_builder: BpeTrainerBuilder,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Every setter below takes the builder by value, forwards its argument to the
// method of the same name on the wrapped `BpeTrainerBuilder`, stores the
// builder returned by that call back into the field, and returns `self` so
// calls can be chained.
impl WordPieceTrainerBuilder {
|
||||||
|
/// Constructs a new `WordPieceTrainerBuilder` from its `Default` value
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the expected minimum frequency (forwarded to the BPE builder)
|
||||||
|
pub fn min_frequency(mut self, frequency: u32) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.min_frequency(frequency);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the vocabulary size (forwarded to the BPE builder)
|
||||||
|
pub fn vocab_size(mut self, size: usize) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.vocab_size(size);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set whether to show progress (forwarded to the BPE builder)
|
||||||
|
pub fn show_progress(mut self, show: bool) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.show_progress(show);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the special tokens (forwarded to the BPE builder)
|
||||||
|
pub fn special_tokens(mut self, tokens: Vec<String>) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.special_tokens(tokens);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the alphabet limit (forwarded to the BPE builder)
// NOTE(review): `limit` is a `usize`, not a flag; its exact meaning is defined
// by `BpeTrainerBuilder::limit_alphabet`, which is not visible here - confirm.
|
||||||
|
pub fn limit_alphabet(mut self, limit: usize) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.limit_alphabet(limit);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the initial alphabet (forwarded to the BPE builder)
|
||||||
|
pub fn initial_alphabet(mut self, alphabet: HashSet<char>) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.initial_alphabet(alphabet);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the continuing_subword_prefix (forwarded to the BPE builder)
|
||||||
|
pub fn continuing_subword_prefix(mut self, prefix: String) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.continuing_subword_prefix(prefix);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the end_of_word_suffix (forwarded to the BPE builder)
|
||||||
|
pub fn end_of_word_suffix(mut self, suffix: String) -> Self {
|
||||||
|
self.bpe_trainer_builder = self.bpe_trainer_builder.end_of_word_suffix(suffix);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs the final `WordPieceTrainer`
///
/// Builds the wrapped `BpeTrainer` first; if that build fails, its error is
/// returned unchanged through `?`.
|
||||||
|
pub fn build(self) -> Result<WordPieceTrainer> {
|
||||||
|
let bpe_trainer = self.bpe_trainer_builder.build()?;
|
||||||
|
Ok(WordPieceTrainer { bpe_trainer })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trainer that produces `WordPiece` models.
///
/// It owns a `BpeTrainer` and delegates to it: token counting and the training
/// run itself are both performed by the wrapped trainer.
#[derive(Default)]
|
||||||
|
pub struct WordPieceTrainer {
|
||||||
|
// The BPE trainer that performs the actual work.
bpe_trainer: BpeTrainer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WordPieceTrainer {
|
||||||
|
/// Returns a `WordPieceTrainerBuilder` in its default state
pub fn builder() -> WordPieceTrainerBuilder {
|
||||||
|
WordPieceTrainerBuilder::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Train a WordPiece model
///
/// Runs the wrapped BPE trainer on `word_counts` and converts the model it
/// returns with `WordPiece::from_bpe`. An error from the BPE trainer is
/// returned unchanged through `?`.
// NOTE(review): `word_counts` presumably maps each word to its number of
// occurrences - verify against the caller.
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<WordPiece> {
|
||||||
|
let bpe = self.bpe_trainer.train(word_counts)?;
|
||||||
|
Ok(WordPiece::from_bpe(&bpe))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `Trainer` implementation: both methods delegate, `train` to the inherent
// `WordPieceTrainer::train` and `process_tokens` to the wrapped BPE trainer.
impl Trainer for WordPieceTrainer {
|
||||||
|
/// Train a WordPiece model and return it boxed as a `dyn Model + Sync`
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
|
||||||
|
// Not a recursive call: method lookup picks the inherent
// `WordPieceTrainer::train` (returning `Result<WordPiece>`) ahead of this
// trait method.
let wp = self.train(word_counts)?;
|
||||||
|
Ok(Box::new(wp))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a bunch of tokens by handing them, together with `words`, to the
/// wrapped BPE trainer's `process_tokens`
// NOTE(review): `words` is already a `&mut HashMap`, so `mut words` plus
// `&mut words` builds a `&mut &mut HashMap` that is then dereferenced again;
// passing `words` directly looks equivalent (clippy `needless_borrow`) - confirm.
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||||
|
self.bpe_trainer.process_tokens(&mut words, tokens)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user