diff --git a/bindings/node/native/src/lib.rs b/bindings/node/native/src/lib.rs index efaab045..09b2a1b9 100644 --- a/bindings/node/native/src/lib.rs +++ b/bindings/node/native/src/lib.rs @@ -9,6 +9,7 @@ mod pre_tokenizers; mod processors; mod tasks; mod tokenizer; +mod trainers; mod utils; use neon::prelude::*; @@ -26,6 +27,8 @@ register_module!(mut m, { normalizers::register(&mut m, "normalizers")?; // PreTokenizers pre_tokenizers::register(&mut m, "pre_tokenizers")?; + // Trainers + trainers::register(&mut m, "trainers")?; Ok(()) }); diff --git a/bindings/node/native/src/trainers.rs b/bindings/node/native/src/trainers.rs new file mode 100644 index 00000000..95590988 --- /dev/null +++ b/bindings/node/native/src/trainers.rs @@ -0,0 +1,203 @@ +extern crate tokenizers as tk; + +use crate::utils::Container; +use neon::prelude::*; +use std::collections::HashSet; + +/// Trainer +pub struct Trainer { + pub trainer: Container, +} + +declare_types! { + pub class JsTrainer for Trainer { + init(_) { + // This should not be called from JS + Ok(Trainer { + trainer: Container::Empty + }) + } + } +} + +/// bpe_trainer(options?: { +/// vocabSize?: number = 30000, +/// minFrequency?: number = 2, +/// specialTokens?: string[] = [], +/// limitAlphabet?: number = undefined, +/// initialAlphabet?: string[] = [], +/// showProgress?: bool = true, +/// continuingSubwordPrefix?: string = undefined, +/// endOfWordSuffix?: string = undefined, +/// }) +fn bpe_trainer(mut cx: FunctionContext) -> JsResult { + let options = cx.argument_opt(0); + + let mut builder = tk::models::bpe::BpeTrainer::builder(); + + if let Some(options) = options { + if let Ok(options) = options.downcast::() { + if let Ok(size) = options.get(&mut cx, "vocabSize") { + builder = builder + .vocab_size(size.downcast::().or_throw(&mut cx)?.value() as usize); + } + if let Ok(freq) = options.get(&mut cx, "minFrequency") { + builder = builder + .min_frequency(freq.downcast::().or_throw(&mut cx)?.value() as u32); + } + if let Ok(tokens) = options.get(&mut cx, "specialTokens") { + builder = builder.special_tokens( + tokens + .downcast::() + .or_throw(&mut cx)? + .to_vec(&mut cx)? + .into_iter() + .map(|token| Ok(token.downcast::().or_throw(&mut cx)?.value())) + .collect::>>()?, + ); + } + if let Ok(limit) = options.get(&mut cx, "limitAlphabet") { + builder = builder.limit_alphabet( + limit.downcast::().or_throw(&mut cx)?.value() as usize, + ); + } + if let Ok(alphabet) = options.get(&mut cx, "initialAlphabet") { + builder = builder.initial_alphabet( + alphabet + .downcast::() + .or_throw(&mut cx)? + .to_vec(&mut cx)? + .into_iter() + .map(|tokens| { + Ok(tokens + .downcast::() + .or_throw(&mut cx)? + .value() + .chars() + .nth(0)) + }) + .collect::>>()? + .into_iter() + .filter(|c| c.is_some()) + .map(|c| c.unwrap()) + .collect::>(), + ); + } + if let Ok(show) = options.get(&mut cx, "showProgress") { + builder = + builder.show_progress(show.downcast::().or_throw(&mut cx)?.value()); + } + if let Ok(prefix) = options.get(&mut cx, "continuingSubwordPrefix") { + builder = builder.continuing_subword_prefix( + prefix.downcast::().or_throw(&mut cx)?.value(), + ); + } + if let Ok(suffix) = options.get(&mut cx, "endOfWordSuffix") { + builder = builder + .end_of_word_suffix(suffix.downcast::().or_throw(&mut cx)?.value()); + } + } + } + + let mut trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; + let guard = cx.lock(); + trainer + .borrow_mut(&guard) + .trainer + .to_owned(Box::new(builder.build())); + Ok(trainer) +} +/// wordpiece_trainer(options?: { +/// vocabSize?: number = 30000, +/// minFrequency?: number = 2, +/// specialTokens?: string[] = [], +/// limitAlphabet?: number = undefined, +/// initialAlphabet?: string[] = [], +/// showProgress?: bool = true, +/// continuingSubwordPrefix?: string = undefined, +/// endOfWordSuffix?: string = undefined, +/// }) +fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult { + let options = cx.argument_opt(0); + + let mut builder = tk::models::wordpiece::WordPieceTrainer::builder(); + + if let Some(options) = options { + if let Ok(options) = options.downcast::() { + if let Ok(size) = options.get(&mut cx, "vocabSize") { + builder = builder + .vocab_size(size.downcast::().or_throw(&mut cx)?.value() as usize); + } + if let Ok(freq) = options.get(&mut cx, "minFrequency") { + builder = builder + .min_frequency(freq.downcast::().or_throw(&mut cx)?.value() as u32); + } + if let Ok(tokens) = options.get(&mut cx, "specialTokens") { + builder = builder.special_tokens( + tokens + .downcast::() + .or_throw(&mut cx)? + .to_vec(&mut cx)? + .into_iter() + .map(|token| Ok(token.downcast::().or_throw(&mut cx)?.value())) + .collect::>>()?, + ); + } + if let Ok(limit) = options.get(&mut cx, "limitAlphabet") { + builder = builder.limit_alphabet( + limit.downcast::().or_throw(&mut cx)?.value() as usize, + ); + } + if let Ok(alphabet) = options.get(&mut cx, "initialAlphabet") { + builder = builder.initial_alphabet( + alphabet + .downcast::() + .or_throw(&mut cx)? + .to_vec(&mut cx)? + .into_iter() + .map(|tokens| { + Ok(tokens + .downcast::() + .or_throw(&mut cx)? + .value() + .chars() + .nth(0)) + }) + .collect::>>()? + .into_iter() + .filter(|c| c.is_some()) + .map(|c| c.unwrap()) + .collect::>(), + ); + } + if let Ok(show) = options.get(&mut cx, "showProgress") { + builder = + builder.show_progress(show.downcast::().or_throw(&mut cx)?.value()); + } + if let Ok(prefix) = options.get(&mut cx, "continuingSubwordPrefix") { + builder = builder.continuing_subword_prefix( + prefix.downcast::().or_throw(&mut cx)?.value(), + ); + } + if let Ok(suffix) = options.get(&mut cx, "endOfWordSuffix") { + builder = builder + .end_of_word_suffix(suffix.downcast::().or_throw(&mut cx)?.value()); + } + } + } + + let mut trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?; + let guard = cx.lock(); + trainer + .borrow_mut(&guard) + .trainer + .to_owned(Box::new(builder.build())); + Ok(trainer) +} + +/// Register everything here +pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> { + m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?; + m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?; + Ok(()) +}