Node - Add WordLevelTrainer

Anthony MOI
2020-11-19 20:01:28 -05:00
committed by Anthony MOI
parent 7fc37a03e8
commit 13e07da2c8
5 changed files with 180 additions and 32 deletions

View File

@@ -63,6 +63,35 @@ export function bpeTrainer(options?: TrainerOptions): Trainer;
*/
export function wordPieceTrainer(options?: TrainerOptions): Trainer;
export interface WordLevelTrainerOptions {
/**
* The minimum frequency a word must have to be part of the vocabulary.
* @default 0
*/
minFrequency?: number;
/**
* Whether to show progress bars while training.
* @default true
*/
showProgress?: boolean;
/**
* A list of special tokens the model should know of.
* @default []
*/
specialTokens?: (string | AddedToken)[];
/**
* The size of the final vocabulary, including all tokens.
* @default 30000
*/
vocabSize?: number;
}
/**
* Instantiate a new WordLevel Trainer
* @param [options] WordLevel Trainer options
*/
export function wordLevelTrainer(options?: WordLevelTrainerOptions): Trainer;
export interface UnigramTrainerOptions {
vocabSize?: number;
nSubIterations?: number;

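For reference, a minimal usage sketch of the new declaration. The import path is an assumption (the package re-exports these bindings, but the exact entry point is not part of this diff):

import { wordLevelTrainer } from "tokenizers";

// All options are optional; omitted fields fall back to the defaults
// documented in WordLevelTrainerOptions above.
const trainer = wordLevelTrainer({
  vocabSize: 5000,
  minFrequency: 2,
  specialTokens: ["[UNK]", "[PAD]"],
  showProgress: false,
});
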
View File

@@ -3,5 +3,6 @@ const native = require("./native");
module.exports = {
bpeTrainer: native.trainers_BPETrainer,
wordPieceTrainer: native.trainers_WordPieceTrainer,
wordLevelTrainer: native.trainers_WordLevelTrainer,
unigramTrainer: native.trainers_UnigramTrainer,
};

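Since the shim above maps the export one-to-one onto the native function, the trainer is equally reachable from plain CommonJS. A sketch, with the require path again assumed:

// Plain CommonJS access to the same native binding (path assumed).
const { wordLevelTrainer } = require("tokenizers/bindings/trainers");

// With no options, the native side falls back to
// WordLevelTrainer::builder() defaults (see the Rust binding below).
const trainer = wordLevelTrainer();
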
View File

@@ -8,7 +8,8 @@ use std::collections::HashMap;
use std::sync::Arc;
use tk::models::{
- bpe::BpeTrainer, unigram::UnigramTrainer, wordpiece::WordPieceTrainer, TrainerWrapper,
+ bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer,
+ wordpiece::WordPieceTrainer, TrainerWrapper,
};
/// Trainer
@@ -17,6 +18,14 @@ pub struct Trainer {
pub trainer: Option<Arc<TrainerWrapper>>,
}
impl From<TrainerWrapper> for Trainer {
fn from(trainer: TrainerWrapper) -> Self {
Self {
trainer: Some(Arc::new(trainer)),
}
}
}
impl tk::Trainer for Trainer {
type Model = Model;
@@ -27,14 +36,26 @@ impl tk::Trainer for Trainer {
.should_show_progress()
}
- fn train(&self, words: HashMap<String, u32>) -> tk::Result<(Self::Model, Vec<tk::AddedToken>)> {
- let (model, special_tokens) = self
+ fn train(
+ &self,
+ words: HashMap<String, u32>,
+ model: &mut Self::Model,
+ ) -> tk::Result<Vec<tk::AddedToken>> {
+ let special_tokens = self
.trainer
.as_ref()
.ok_or("Uninitialized Trainer")?
- .train(words)?;
+ .train(
+ words,
+ &mut model
+ .model
+ .as_ref()
+ .ok_or("Uninitialized Model")?
+ .write()
+ .unwrap(),
+ )?;
- Ok((model.into(), special_tokens))
+ Ok(special_tokens)
}
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
@@ -238,6 +259,81 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
Ok(js_trainer)
}
// WordLevel
struct WordLevelTrainerOptions(WordLevelTrainer);
impl From<WordLevelTrainerOptions> for WordLevelTrainer {
fn from(v: WordLevelTrainerOptions) -> Self {
v.0
}
}
impl FromJsValue for WordLevelTrainerOptions {
fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
if let Ok(options) = from.downcast::<JsObject>() {
let mut builder = WordLevelTrainer::builder();
if let Ok(size) = options.get(cx, "vocabSize") {
if let Some(size) = Option::from_value(size, cx)? {
builder.vocab_size(size);
}
}
if let Ok(freq) = options.get(cx, "minFrequency") {
if let Some(freq) = Option::from_value(freq, cx)? {
builder.min_frequency(freq);
}
}
if let Ok(tokens) = options.get(cx, "specialTokens") {
if tokens.downcast::<JsNull>().is_err() && tokens.downcast::<JsUndefined>().is_err()
{
builder.special_tokens(
tokens
.downcast::<JsArray>()
.map_err(|e| Error(format!("{}", e)))?
.to_vec(cx)?
.into_iter()
.map(|token| Ok(AddedToken::from_value(token, cx)?.into()))
.collect::<Result<Vec<_>, Error>>()?,
);
}
}
if let Ok(show) = options.get(cx, "showProgress") {
if let Some(show) = Option::from_value(show, cx)? {
builder.show_progress(show);
}
}
Ok(Self(
builder
.build()
.expect("WordLevelTrainerBuilder cannot fail"),
))
} else {
Err(Error("Expected options type: object".into()))
}
}
}
/// wordlevel_trainer(options?: {
/// vocabSize?: number = 30000,
/// minFrequency?: number = 0,
/// specialTokens?: string[] = [],
/// showProgress?: bool = true,
/// })
fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
let trainer = cx.extract_opt::<WordLevelTrainerOptions>(0)?.map_or_else(
|| WordLevelTrainer::builder().build().unwrap(),
|o| o.into(),
);
let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
Ok(js_trainer)
}
// Unigram
struct UnigramTrainerOptions(UnigramTrainer);
impl From<UnigramTrainerOptions> for UnigramTrainer {
fn from(v: UnigramTrainerOptions) -> Self {
@@ -337,6 +433,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?;
m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?;
m.export_function(&format!("{}_WordLevelTrainer", prefix), wordlevel_trainer)?;
m.export_function(&format!("{}_UnigramTrainer", prefix), unigram_trainer)?;
Ok(())
}

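On the JavaScript side, the option parsing above means specialTokens may be omitted (null and undefined are tolerated, not errors) and may mix plain strings with AddedToken instances, each converted through AddedToken::from_value. A hedged sketch; the AddedToken constructor shape is an assumption carried over from the other trainers in these bindings:

import { wordLevelTrainer, AddedToken } from "tokenizers";

// Explicitly undefined specialTokens are treated as "not set".
const defaults = wordLevelTrainer({ specialTokens: undefined });

// Strings and AddedToken instances can be mixed in one list.
const custom = wordLevelTrainer({
  specialTokens: ["[UNK]", new AddedToken("[MASK]", true)], // constructor shape assumed
});
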
View File

@@ -301,34 +301,41 @@ impl PyWordLevelTrainer {
#[new]
#[args(kwargs = "**")]
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
- let mut trainer = tk::models::wordlevel::WordLevelTrainer::default();
+ let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"vocab_size" => trainer.vocab_size = val.extract()?,
"min_frequency" => trainer.min_frequency = val.extract()?,
"show_progress" => trainer.show_progress = val.extract()?,
"vocab_size" => {
builder.vocab_size(val.extract()?);
}
"min_frequency" => {
builder.min_frequency(val.extract()?);
}
"show_progress" => {
builder.show_progress(val.extract()?);
}
"special_tokens" => {
- trainer.special_tokens = val
- .cast_as::<PyList>()?
- .into_iter()
- .map(|token| {
- if let Ok(content) = token.extract::<String>() {
- Ok(PyAddedToken::from(content, Some(true)).get_token())
- } else if let Ok(mut token) =
- token.extract::<PyRefMut<PyAddedToken>>()
- {
- token.is_special_token = true;
- Ok(token.get_token())
- } else {
- Err(exceptions::PyTypeError::new_err(
- "special_tokens must be a List[Union[str, AddedToken]]",
- ))
- }
- })
- .collect::<PyResult<Vec<_>>>()?
+ builder.special_tokens(
+ val.cast_as::<PyList>()?
+ .into_iter()
+ .map(|token| {
+ if let Ok(content) = token.extract::<String>() {
+ Ok(PyAddedToken::from(content, Some(true)).get_token())
+ } else if let Ok(mut token) =
+ token.extract::<PyRefMut<PyAddedToken>>()
+ {
+ token.is_special_token = true;
+ Ok(token.get_token())
+ } else {
+ Err(exceptions::PyTypeError::new_err(
+ "special_tokens must be a List[Union[str, AddedToken]]",
+ ))
+ }
+ })
+ .collect::<PyResult<Vec<_>>>()?,
+ );
}
_ => println!("Ignored unknown kwargs option {}", key),
}
@@ -337,7 +344,12 @@ impl PyWordLevelTrainer {
Ok((
PyWordLevelTrainer {},
- PyTrainer::new(Arc::new(trainer.into())),
+ PyTrainer::new(Arc::new(
+ builder
+ .build()
+ .expect("WordLevelTrainerBuilder cannot fail")
+ .into(),
+ )),
))
}
}

View File

@@ -2,15 +2,20 @@ use super::WordLevel;
use crate::{AddedToken, Result, Trainer};
use std::collections::HashMap;
#[derive(Debug, Clone, Builder)]
pub struct WordLevelTrainer {
/// The minimum frequency a word must have to be part of the vocabulary
- pub min_frequency: u32,
+ #[builder(default)]
+ min_frequency: u32,
/// The target vocabulary size
- pub vocab_size: usize,
+ #[builder(default)]
+ vocab_size: usize,
/// Whether to show progress while training
- pub show_progress: bool,
+ #[builder(default)]
+ show_progress: bool,
/// A list of special tokens that the model should know of
- pub special_tokens: Vec<AddedToken>,
+ #[builder(default)]
+ special_tokens: Vec<AddedToken>,
}
impl Default for WordLevelTrainer {
@@ -25,6 +30,10 @@
}
impl WordLevelTrainer {
pub fn builder() -> WordLevelTrainerBuilder {
WordLevelTrainerBuilder::default()
}
fn train(
&self,
word_counts: HashMap<String, u32>,