From 7fc37a03e87963a21f7b27c82ef0712d58fa38bc Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 19 Nov 2020 19:57:50 -0500 Subject: [PATCH] Node - Trainers train the Model in-place --- .../examples/documentation/pipeline.test.ts | 11 +--- .../examples/documentation/quicktour.test.ts | 13 +--- bindings/node/native/src/models.rs | 61 +++++++++++++------ bindings/node/native/src/tasks/models.rs | 8 +-- bindings/node/native/src/tokenizer.rs | 24 ++++++-- 5 files changed, 69 insertions(+), 48 deletions(-) diff --git a/bindings/node/examples/documentation/pipeline.test.ts b/bindings/node/examples/documentation/pipeline.test.ts index c7d963e3..f3623710 100644 --- a/bindings/node/examples/documentation/pipeline.test.ts +++ b/bindings/node/examples/documentation/pipeline.test.ts @@ -94,7 +94,7 @@ describe("pipelineExample", () => { let { Tokenizer } = require("tokenizers/bindings/tokenizer"); let { WordPiece } = require("tokenizers/bindings/models"); - let bertTokenizer = new Tokenizer(WordPiece.empty()); + let bertTokenizer = new Tokenizer(WordPiece.init({}, { unkToken: "[UNK]" })); // END bert_setup_tokenizer // START bert_setup_normalizer let { sequenceNormalizer, lowercaseNormalizer, nfdNormalizer, stripAccentsNormalizer } @@ -120,20 +120,13 @@ describe("pipelineExample", () => { // END bert_setup_processor // START bert_train_tokenizer let { wordPieceTrainer } = require("tokenizers/bindings/trainers"); - let { promisify } = require("util"); let trainer = wordPieceTrainer({ vocabSize: 30522, specialTokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"] }); let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`); - bertTokenizer.train(trainer, files); - - let modelFiles = bertTokenizer.getModel().save("data", "bert-wiki"); - let fromFile = promisify(WordPiece.fromFile); - bertTokenizer.setModel(await fromFile(modelFiles[0], { - unkToken: "[UNK]" - })); + bertTokenizer.train(files, trainer); bertTokenizer.save("data/bert-wiki.json") // END bert_train_tokenizer diff --git a/bindings/node/examples/documentation/quicktour.test.ts b/bindings/node/examples/documentation/quicktour.test.ts index f144efec..a91964b0 100644 --- a/bindings/node/examples/documentation/quicktour.test.ts +++ b/bindings/node/examples/documentation/quicktour.test.ts @@ -16,7 +16,7 @@ describe("quicktourExample", () => { let { Tokenizer } = require("tokenizers/bindings/tokenizer"); let { BPE } = require("tokenizers/bindings/models"); - let tokenizer = new Tokenizer(BPE.empty()); + let tokenizer = new Tokenizer(BPE.init({}, [], { unkToken: "[UNK]" })); // END init_tokenizer // START init_trainer let { bpeTrainer } = require("tokenizers/bindings/trainers"); @@ -32,17 +32,8 @@ describe("quicktourExample", () => { // END init_pretok // START train let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`); - tokenizer.train(trainer, files); + tokenizer.train(files, trainer); // END train - // START reload_model - let { promisify } = require("util"); - - let modelFiles = tokenizer.getModel().save("data", "wiki"); - let fromFile = promisify(BPE.fromFile); - tokenizer.setModel(await fromFile(modelFiles[0], modelFiles[1], { - unkToken: "[UNK]" - })); - // END reload_model // START save tokenizer.save("data/tokenizer-wiki.json"); // END save diff --git a/bindings/node/native/src/models.rs b/bindings/node/native/src/models.rs index c1e45dbb..a64187e7 100644 --- a/bindings/node/native/src/models.rs +++ b/bindings/node/native/src/models.rs @@ -2,11 +2,12 @@ extern crate tokenizers as tk; use crate::extraction::*; use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask}; +use crate::trainers::Trainer; use neon::prelude::*; use std::collections::HashMap; use std::path::Path; use std::path::PathBuf; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tk::models::{ bpe::{BpeBuilder, Merges, Vocab}, @@ -21,37 +22,46 @@ use tk::Token; #[derive(Clone, Serialize, Deserialize)] pub struct Model { #[serde(flatten)] - pub model: Option>, + pub model: Option>>, } -impl From for Model { - fn from(wrapper: ModelWrapper) -> Self { +impl From for Model +where + M: Into, +{ + fn from(wrapper: M) -> Self { Self { - model: Some(Arc::new(wrapper)), + model: Some(Arc::new(RwLock::new(wrapper.into()))), } } } impl tk::Model for Model { + type Trainer = Trainer; + fn tokenize(&self, sequence: &str) -> tk::Result> { self.model .as_ref() .ok_or("Uninitialized Model")? + .read() + .unwrap() .tokenize(sequence) } fn token_to_id(&self, token: &str) -> Option { - self.model.as_ref()?.token_to_id(token) + self.model.as_ref()?.read().unwrap().token_to_id(token) } - fn id_to_token(&self, id: u32) -> Option<&str> { - self.model.as_ref()?.id_to_token(id) + fn id_to_token(&self, id: u32) -> Option { + self.model.as_ref()?.read().unwrap().id_to_token(id) } - fn get_vocab(&self) -> &HashMap { + fn get_vocab(&self) -> HashMap { self.model .as_ref() .expect("Uninitialized Model") + .read() + .unwrap() .get_vocab() } @@ -59,6 +69,8 @@ impl tk::Model for Model { self.model .as_ref() .expect("Uninitialized Model") + .read() + .unwrap() .get_vocab_size() } @@ -66,8 +78,20 @@ impl tk::Model for Model { self.model .as_ref() .ok_or("Uninitialized Model")? + .read() + .unwrap() .save(folder, name) } + + fn get_trainer(&self) -> Self::Trainer { + self.model + .as_ref() + .expect("Uninitialized Model") + .read() + .unwrap() + .get_trainer() + .into() + } } declare_types! { @@ -86,7 +110,8 @@ declare_types! { let guard = cx.lock(); let files = this.borrow(&guard) - .model.as_ref().unwrap() + .model.as_ref().expect("Uninitialized Model") + .read().unwrap() .save( Path::new(&folder), name.as_deref() @@ -153,7 +178,7 @@ fn bpe_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(model.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into()))); Ok(js_model) } @@ -191,7 +216,7 @@ fn bpe_empty(mut cx: FunctionContext) -> JsResult { let bpe = tk::models::bpe::BPE::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(bpe.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into()))); Ok(model) } @@ -236,7 +261,7 @@ fn wordpiece_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(model.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into()))); Ok(js_model) } @@ -270,7 +295,7 @@ fn wordpiece_empty(mut cx: FunctionContext) -> JsResult { let wordpiece = tk::models::wordpiece::WordPiece::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into()))); Ok(model) } @@ -305,7 +330,7 @@ fn wordlevel_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(model.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into()))); Ok(js_model) } @@ -337,7 +362,7 @@ fn wordlevel_empty(mut cx: FunctionContext) -> JsResult { let wordlevel = tk::models::wordlevel::WordLevel::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into()))); Ok(model) } @@ -362,7 +387,7 @@ fn unigram_init(mut cx: FunctionContext) -> JsResult { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(unigram.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into()))); Ok(js_model) } @@ -373,7 +398,7 @@ fn unigram_empty(mut cx: FunctionContext) -> JsResult { let unigram = tk::models::unigram::Unigram::default(); let guard = cx.lock(); - model.borrow_mut(&guard).model = Some(Arc::new(unigram.into())); + model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into()))); Ok(model) } diff --git a/bindings/node/native/src/tasks/models.rs b/bindings/node/native/src/tasks/models.rs index cc266e28..42c3fced 100644 --- a/bindings/node/native/src/tasks/models.rs +++ b/bindings/node/native/src/tasks/models.rs @@ -2,7 +2,7 @@ extern crate tokenizers as tk; use crate::models::*; use neon::prelude::*; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use tk::models::bpe::{BpeBuilder, BPE}; use tk::models::wordlevel::{WordLevel, WordLevelBuilder}; use tk::models::wordpiece::{WordPiece, WordPieceBuilder}; @@ -34,7 +34,7 @@ impl Task for WordPieceFromFilesTask { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into()))); Ok(js_model.upcast()) } @@ -67,7 +67,7 @@ impl Task for WordLevelFromFilesTask { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into()))); Ok(js_model.upcast()) } @@ -100,7 +100,7 @@ impl Task for BPEFromFilesTask { let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?; let guard = cx.lock(); - js_model.borrow_mut(&guard).model = Some(Arc::new(bpe.into())); + js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into()))); Ok(js_model.upcast()) } diff --git a/bindings/node/native/src/tokenizer.rs b/bindings/node/native/src/tokenizer.rs index 50d5054b..0d526495 100644 --- a/bindings/node/native/src/tokenizer.rs +++ b/bindings/node/native/src/tokenizer.rs @@ -12,6 +12,7 @@ use crate::trainers::JsTrainer; use neon::prelude::*; use std::sync::{Arc, RwLock}; +use tk::Model as ModelTrait; use tk::TokenizerImpl; // AddedToken @@ -634,7 +635,7 @@ declare_types! { let guard = cx.lock(); let token = this.borrow(&guard) .tokenizer.read().unwrap() - .id_to_token(id).map(|t| t.to_owned()); + .id_to_token(id); if let Some(token) = token { Ok(cx.string(token).upcast()) @@ -745,18 +746,29 @@ declare_types! { } method train(mut cx) { - // train(trainer: JsTrainer, files: string[]) + // train(files: string[], trainer?: Trainer) - let trainer = cx.argument::(0)?; - let files = cx.extract::>(1)?; + let files = cx.extract::>(0)?; + let trainer = if let Some(val) = cx.argument_opt(1) { + let js_trainer = val.downcast::().or_throw(&mut cx)?; + let guard = cx.lock(); + + let trainer = js_trainer.borrow(&guard).clone(); + trainer + } else { + let this = cx.this(); + let guard = cx.lock(); + + let trainer = this.borrow(&guard).tokenizer.read().unwrap().get_model().get_trainer(); + trainer + }; let mut this = cx.this(); let guard = cx.lock(); - let trainer = trainer.borrow(&guard).clone(); this.borrow_mut(&guard) .tokenizer.write().unwrap() - .train_and_replace(&trainer, files) + .train(&trainer, files) .map_err(|e| Error(format!("{}", e)))?; Ok(cx.undefined().upcast())