Node - Trainers train the Model in-place

This commit is contained in:
Anthony MOI
2020-11-19 19:57:50 -05:00
committed by Anthony MOI
parent 387b8a1033
commit 7fc37a03e8
5 changed files with 69 additions and 48 deletions

View File

@ -94,7 +94,7 @@ describe("pipelineExample", () => {
let { Tokenizer } = require("tokenizers/bindings/tokenizer");
let { WordPiece } = require("tokenizers/bindings/models");
let bertTokenizer = new Tokenizer(WordPiece.empty());
let bertTokenizer = new Tokenizer(WordPiece.init({}, { unkToken: "[UNK]" }));
// END bert_setup_tokenizer
// START bert_setup_normalizer
let { sequenceNormalizer, lowercaseNormalizer, nfdNormalizer, stripAccentsNormalizer }
@ -120,20 +120,13 @@ describe("pipelineExample", () => {
// END bert_setup_processor
// START bert_train_tokenizer
let { wordPieceTrainer } = require("tokenizers/bindings/trainers");
let { promisify } = require("util");
let trainer = wordPieceTrainer({
vocabSize: 30522,
specialTokens: ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
});
let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`);
bertTokenizer.train(trainer, files);
let modelFiles = bertTokenizer.getModel().save("data", "bert-wiki");
let fromFile = promisify(WordPiece.fromFile);
bertTokenizer.setModel(await fromFile(modelFiles[0], {
unkToken: "[UNK]"
}));
bertTokenizer.train(files, trainer);
bertTokenizer.save("data/bert-wiki.json")
// END bert_train_tokenizer

View File

@ -16,7 +16,7 @@ describe("quicktourExample", () => {
let { Tokenizer } = require("tokenizers/bindings/tokenizer");
let { BPE } = require("tokenizers/bindings/models");
let tokenizer = new Tokenizer(BPE.empty());
let tokenizer = new Tokenizer(BPE.init({}, [], { unkToken: "[UNK]" }));
// END init_tokenizer
// START init_trainer
let { bpeTrainer } = require("tokenizers/bindings/trainers");
@ -32,17 +32,8 @@ describe("quicktourExample", () => {
// END init_pretok
// START train
let files = ["test", "train", "valid"].map(split => `data/wikitext-103-raw/wiki.${split}.raw`);
tokenizer.train(trainer, files);
tokenizer.train(files, trainer);
// END train
// START reload_model
let { promisify } = require("util");
let modelFiles = tokenizer.getModel().save("data", "wiki");
let fromFile = promisify(BPE.fromFile);
tokenizer.setModel(await fromFile(modelFiles[0], modelFiles[1], {
unkToken: "[UNK]"
}));
// END reload_model
// START save
tokenizer.save("data/tokenizer-wiki.json");
// END save

View File

@ -2,11 +2,12 @@ extern crate tokenizers as tk;
use crate::extraction::*;
use crate::tasks::models::{BPEFromFilesTask, WordLevelFromFilesTask, WordPieceFromFilesTask};
use crate::trainers::Trainer;
use neon::prelude::*;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tk::models::{
bpe::{BpeBuilder, Merges, Vocab},
@ -21,37 +22,46 @@ use tk::Token;
#[derive(Clone, Serialize, Deserialize)]
pub struct Model {
#[serde(flatten)]
pub model: Option<Arc<ModelWrapper>>,
pub model: Option<Arc<RwLock<ModelWrapper>>>,
}
impl From<ModelWrapper> for Model {
fn from(wrapper: ModelWrapper) -> Self {
impl<M> From<M> for Model
where
M: Into<ModelWrapper>,
{
fn from(wrapper: M) -> Self {
Self {
model: Some(Arc::new(wrapper)),
model: Some(Arc::new(RwLock::new(wrapper.into()))),
}
}
}
impl tk::Model for Model {
type Trainer = Trainer;
fn tokenize(&self, sequence: &str) -> tk::Result<Vec<Token>> {
self.model
.as_ref()
.ok_or("Uninitialized Model")?
.read()
.unwrap()
.tokenize(sequence)
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.model.as_ref()?.token_to_id(token)
self.model.as_ref()?.read().unwrap().token_to_id(token)
}
fn id_to_token(&self, id: u32) -> Option<&str> {
self.model.as_ref()?.id_to_token(id)
fn id_to_token(&self, id: u32) -> Option<String> {
self.model.as_ref()?.read().unwrap().id_to_token(id)
}
fn get_vocab(&self) -> &HashMap<String, u32> {
fn get_vocab(&self) -> HashMap<String, u32> {
self.model
.as_ref()
.expect("Uninitialized Model")
.read()
.unwrap()
.get_vocab()
}
@ -59,6 +69,8 @@ impl tk::Model for Model {
self.model
.as_ref()
.expect("Uninitialized Model")
.read()
.unwrap()
.get_vocab_size()
}
@ -66,8 +78,20 @@ impl tk::Model for Model {
self.model
.as_ref()
.ok_or("Uninitialized Model")?
.read()
.unwrap()
.save(folder, name)
}
fn get_trainer(&self) -> Self::Trainer {
self.model
.as_ref()
.expect("Uninitialized Model")
.read()
.unwrap()
.get_trainer()
.into()
}
}
declare_types! {
@ -86,7 +110,8 @@ declare_types! {
let guard = cx.lock();
let files = this.borrow(&guard)
.model.as_ref().unwrap()
.model.as_ref().expect("Uninitialized Model")
.read().unwrap()
.save(
Path::new(&folder),
name.as_deref()
@ -153,7 +178,7 @@ fn bpe_init(mut cx: FunctionContext) -> JsResult<JsModel> {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(model.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into())));
Ok(js_model)
}
@ -191,7 +216,7 @@ fn bpe_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let bpe = tk::models::bpe::BPE::default();
let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(bpe.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into())));
Ok(model)
}
@ -236,7 +261,7 @@ fn wordpiece_init(mut cx: FunctionContext) -> JsResult<JsModel> {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(model.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into())));
Ok(js_model)
}
@ -270,7 +295,7 @@ fn wordpiece_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let wordpiece = tk::models::wordpiece::WordPiece::default();
let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into())));
Ok(model)
}
@ -305,7 +330,7 @@ fn wordlevel_init(mut cx: FunctionContext) -> JsResult<JsModel> {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(model.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(model.into())));
Ok(js_model)
}
@ -337,7 +362,7 @@ fn wordlevel_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let wordlevel = tk::models::wordlevel::WordLevel::default();
let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into())));
Ok(model)
}
@ -362,7 +387,7 @@ fn unigram_init(mut cx: FunctionContext) -> JsResult<JsModel> {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(unigram.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into())));
Ok(js_model)
}
@ -373,7 +398,7 @@ fn unigram_empty(mut cx: FunctionContext) -> JsResult<JsModel> {
let unigram = tk::models::unigram::Unigram::default();
let guard = cx.lock();
model.borrow_mut(&guard).model = Some(Arc::new(unigram.into()));
model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(unigram.into())));
Ok(model)
}

View File

@ -2,7 +2,7 @@ extern crate tokenizers as tk;
use crate::models::*;
use neon::prelude::*;
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use tk::models::bpe::{BpeBuilder, BPE};
use tk::models::wordlevel::{WordLevel, WordLevelBuilder};
use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
@ -34,7 +34,7 @@ impl Task for WordPieceFromFilesTask {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(wordpiece.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordpiece.into())));
Ok(js_model.upcast())
}
@ -67,7 +67,7 @@ impl Task for WordLevelFromFilesTask {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(wordlevel.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(wordlevel.into())));
Ok(js_model.upcast())
}
@ -100,7 +100,7 @@ impl Task for BPEFromFilesTask {
let mut js_model = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_model.borrow_mut(&guard).model = Some(Arc::new(bpe.into()));
js_model.borrow_mut(&guard).model = Some(Arc::new(RwLock::new(bpe.into())));
Ok(js_model.upcast())
}

View File

@ -12,6 +12,7 @@ use crate::trainers::JsTrainer;
use neon::prelude::*;
use std::sync::{Arc, RwLock};
use tk::Model as ModelTrait;
use tk::TokenizerImpl;
// AddedToken
@ -634,7 +635,7 @@ declare_types! {
let guard = cx.lock();
let token = this.borrow(&guard)
.tokenizer.read().unwrap()
.id_to_token(id).map(|t| t.to_owned());
.id_to_token(id);
if let Some(token) = token {
Ok(cx.string(token).upcast())
@ -745,18 +746,29 @@ declare_types! {
}
method train(mut cx) {
// train(trainer: JsTrainer, files: string[])
// train(files: string[], trainer?: Trainer)
let trainer = cx.argument::<JsTrainer>(0)?;
let files = cx.extract::<Vec<String>>(1)?;
let files = cx.extract::<Vec<String>>(0)?;
let trainer = if let Some(val) = cx.argument_opt(1) {
let js_trainer = val.downcast::<JsTrainer>().or_throw(&mut cx)?;
let guard = cx.lock();
let trainer = js_trainer.borrow(&guard).clone();
trainer
} else {
let this = cx.this();
let guard = cx.lock();
let trainer = this.borrow(&guard).tokenizer.read().unwrap().get_model().get_trainer();
trainer
};
let mut this = cx.this();
let guard = cx.lock();
let trainer = trainer.borrow(&guard).clone();
this.borrow_mut(&guard)
.tokenizer.write().unwrap()
.train_and_replace(&trainer, files)
.train(&trainer, files)
.map_err(|e| Error(format!("{}", e)))?;
Ok(cx.undefined().upcast())