diff --git a/bindings/node/lib/bindings/trainers.d.ts b/bindings/node/lib/bindings/trainers.d.ts index a08e6e14..357c096a 100644 --- a/bindings/node/lib/bindings/trainers.d.ts +++ b/bindings/node/lib/bindings/trainers.d.ts @@ -63,6 +63,35 @@ export function bpeTrainer(options?: TrainerOptions): Trainer; */ export function wordPieceTrainer(options?: TrainerOptions): Trainer; +export interface WordLevelTrainerOptions { + /** + * The minimum frequency a pair should have in order to be merged. + * @default 2 + */ + minFrequency?: number; + /** + * Whether to show progress bars while training. + * @default true + */ + showProgress?: boolean; + /** + * A list of special tokens the model should know of. + * @default [] + */ + specialTokens?: (string | AddedToken)[]; + /** + * The size of the final vocabulary, including all tokens and alphabet. + * @default 30000 + */ + vocabSize?: number; +} + +/** + * Instantiate a new WordLevel Trainer + * @param [options] WordLevel Trainer options + */ +export function wordLevelTrainer(options?: WordLevelTrainerOptions): Trainer; + export interface UnigramTrainerOptions { vocabSize?: number; nSubIterations?: number; diff --git a/bindings/node/lib/bindings/trainers.js b/bindings/node/lib/bindings/trainers.js index 9521d153..1a6d019b 100644 --- a/bindings/node/lib/bindings/trainers.js +++ b/bindings/node/lib/bindings/trainers.js @@ -3,5 +3,6 @@ const native = require("./native"); module.exports = { bpeTrainer: native.trainers_BPETrainer, wordPieceTrainer: native.trainers_WordPieceTrainer, + wordLevelTrainer: native.trainers_WordLevelTrainer, unigramTrainer: native.trainers_UnigramTrainer, }; diff --git a/bindings/node/native/src/trainers.rs b/bindings/node/native/src/trainers.rs index 93cba7cc..a58f77c1 100644 --- a/bindings/node/native/src/trainers.rs +++ b/bindings/node/native/src/trainers.rs @@ -8,7 +8,8 @@ use std::collections::HashMap; use std::sync::Arc; use tk::models::{ - bpe::BpeTrainer, unigram::UnigramTrainer, wordpiece::WordPieceTrainer, TrainerWrapper, + bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer, + wordpiece::WordPieceTrainer, TrainerWrapper, }; /// Trainer @@ -17,6 +18,14 @@ pub struct Trainer { pub trainer: Option>, } +impl From for Trainer { + fn from(trainer: TrainerWrapper) -> Self { + Self { + trainer: Some(Arc::new(trainer)), + } + } +} + impl tk::Trainer for Trainer { type Model = Model; @@ -27,14 +36,26 @@ impl tk::Trainer for Trainer { .should_show_progress() } - fn train(&self, words: HashMap) -> tk::Result<(Self::Model, Vec)> { - let (model, special_tokens) = self + fn train( + &self, + words: HashMap, + model: &mut Self::Model, + ) -> tk::Result> { + let special_tokens = self .trainer .as_ref() .ok_or("Uninitialized Trainer")? - .train(words)?; + .train( + words, + &mut model + .model + .as_ref() + .ok_or("Uninitialized Model")? + .write() + .unwrap(), + )?; - Ok((model.into(), special_tokens)) + Ok(special_tokens) } fn process_tokens(&self, words: &mut HashMap, tokens: Vec) { @@ -238,6 +259,81 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult { Ok(js_trainer) } +// WordLevel + +struct WordLevelTrainerOptions(WordLevelTrainer); +impl From for WordLevelTrainer { + fn from(v: WordLevelTrainerOptions) -> Self { + v.0 + } +} +impl FromJsValue for WordLevelTrainerOptions { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult { + if let Ok(options) = from.downcast::() { + let mut builder = WordLevelTrainer::builder(); + + if let Ok(size) = options.get(cx, "vocabSize") { + if let Some(size) = Option::from_value(size, cx)? { + builder.vocab_size(size); + } + } + if let Ok(freq) = options.get(cx, "minFrequency") { + if let Some(freq) = Option::from_value(freq, cx)? { + builder.min_frequency(freq); + } + } + if let Ok(tokens) = options.get(cx, "specialTokens") { + if tokens.downcast::().is_err() && tokens.downcast::().is_err() + { + builder.special_tokens( + tokens + .downcast::() + .map_err(|e| Error(format!("{}", e)))? + .to_vec(cx)? + .into_iter() + .map(|token| Ok(AddedToken::from_value(token, cx)?.into())) + .collect::, Error>>()?, + ); + } + } + if let Ok(show) = options.get(cx, "showProgress") { + if let Some(show) = Option::from_value(show, cx)? { + builder.show_progress(show); + } + } + + Ok(Self( + builder + .build() + .expect("WordLevelTrainerBuilder cannot fail"), + )) + } else { + Err(Error("Expected options type: object".into())) + } + } +} + +/// wordlevel_trainer(options?: { +/// vocabSize?: number = 30000, +/// minFrequency?: number = 0, +/// specialTokens?: string[] = [], +/// showProgress?: bool = true, +/// }) +fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult { + let trainer = cx.extract_opt::(0)?.map_or_else( + || WordLevelTrainer::builder().build().unwrap(), + |o| o.into(), + ); + + let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; + let guard = cx.lock(); + js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into())); + + Ok(js_trainer) +} + +// Unigram + struct UnigramTrainerOptions(UnigramTrainer); impl From for UnigramTrainer { fn from(v: UnigramTrainerOptions) -> Self { @@ -337,6 +433,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult { pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> { m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?; m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?; + m.export_function(&format!("{}_WordLevelTrainer", prefix), wordlevel_trainer)?; m.export_function(&format!("{}_UnigramTrainer", prefix), unigram_trainer)?; Ok(()) } diff --git a/bindings/python/src/trainers.rs b/bindings/python/src/trainers.rs index 404adc65..e72f1bd8 100644 --- a/bindings/python/src/trainers.rs +++ b/bindings/python/src/trainers.rs @@ -301,34 +301,41 @@ impl PyWordLevelTrainer { #[new] #[args(kwargs = "**")] pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> { - let mut trainer = tk::models::wordlevel::WordLevelTrainer::default(); + let mut builder = tk::models::wordlevel::WordLevelTrainer::builder(); if let Some(kwargs) = kwargs { for (key, val) in kwargs { let key: &str = key.extract()?; match key { - "vocab_size" => trainer.vocab_size = val.extract()?, - "min_frequency" => trainer.min_frequency = val.extract()?, - "show_progress" => trainer.show_progress = val.extract()?, + "vocab_size" => { + builder.vocab_size(val.extract()?); + } + "min_frequency" => { + builder.min_frequency(val.extract()?); + } + "show_progress" => { + builder.show_progress(val.extract()?); + } "special_tokens" => { - trainer.special_tokens = val - .cast_as::()? - .into_iter() - .map(|token| { - if let Ok(content) = token.extract::() { - Ok(PyAddedToken::from(content, Some(true)).get_token()) - } else if let Ok(mut token) = - token.extract::>() - { - token.is_special_token = true; - Ok(token.get_token()) - } else { - Err(exceptions::PyTypeError::new_err( - "special_tokens must be a List[Union[str, AddedToken]]", - )) - } - }) - .collect::>>()? + builder.special_tokens( + val.cast_as::()? + .into_iter() + .map(|token| { + if let Ok(content) = token.extract::() { + Ok(PyAddedToken::from(content, Some(true)).get_token()) + } else if let Ok(mut token) = + token.extract::>() + { + token.is_special_token = true; + Ok(token.get_token()) + } else { + Err(exceptions::PyTypeError::new_err( + "special_tokens must be a List[Union[str, AddedToken]]", + )) + } + }) + .collect::>>()?, + ); } _ => println!("Ignored unknown kwargs option {}", key), } @@ -337,7 +344,12 @@ impl PyWordLevelTrainer { Ok(( PyWordLevelTrainer {}, - PyTrainer::new(Arc::new(trainer.into())), + PyTrainer::new(Arc::new( + builder + .build() + .expect("WordLevelTrainerBuilder cannot fail") + .into(), + )), )) } } diff --git a/tokenizers/src/models/wordlevel/trainer.rs b/tokenizers/src/models/wordlevel/trainer.rs index 2485b57c..c864529f 100644 --- a/tokenizers/src/models/wordlevel/trainer.rs +++ b/tokenizers/src/models/wordlevel/trainer.rs @@ -2,15 +2,20 @@ use super::WordLevel; use crate::{AddedToken, Result, Trainer}; use std::collections::HashMap; +#[derive(Debug, Clone, Builder)] pub struct WordLevelTrainer { /// The minimum frequency a word must have to be part of the vocabulary - pub min_frequency: u32, + #[builder(default)] + min_frequency: u32, /// The target vocabulary size - pub vocab_size: usize, + #[builder(default)] + vocab_size: usize, /// Whether to show progress while training - pub show_progress: bool, + #[builder(default)] + show_progress: bool, /// A list of special tokens that the model should know of - pub special_tokens: Vec, + #[builder(default)] + special_tokens: Vec, } impl Default for WordLevelTrainer { @@ -25,6 +30,10 @@ impl Default for WordLevelTrainer { } impl WordLevelTrainer { + pub fn builder() -> WordLevelTrainerBuilder { + WordLevelTrainerBuilder::default() + } + fn train( &self, word_counts: HashMap,