Node - Add Model.encodeBatch and make everything async

Anthony MOI
2020-03-24 19:25:39 -04:00
parent eec74ca3e6
commit f79ae40d88
6 changed files with 260 additions and 79 deletions

View File

@@ -251,6 +251,7 @@ version = "0.1.0"
dependencies = [
"neon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"neon-build 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
"tokenizers 0.8.0",
]

View File

@@ -15,4 +15,5 @@ neon-build = "0.3.3"
[dependencies]
neon = "0.3.3"
rayon = "1.2.0"
tokenizers = { path = "../../../tokenizers" }

View File

@@ -1,6 +1,7 @@
#![warn(clippy::all)]
extern crate neon;
extern crate rayon;
extern crate tokenizers as tk;
mod container;

View File

@@ -1,7 +1,6 @@
extern crate tokenizers as tk;
use crate::container::Container;
use crate::encoding::JsEncoding;
use crate::tasks::models::{BPEFromFilesTask, WordPieceFromFilesTask};
use neon::prelude::*;
use std::path::Path;
@@ -52,83 +51,6 @@ declare_types! {
Err(e) => cx.throw_error(format!("{}", e))
}
}
method encode(mut cx) {
/// encode(sequence: (String | [String, [number, number]])[], typeId?: number = 0):
/// Encoding
let sequence = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
let type_id = cx.argument_opt(1)
.map_or(Some(0), |arg| arg.downcast::<JsNumber>()
.ok()
.map(|h| h.value() as u32)
).unwrap();
enum Mode {
NoOffsets,
Offsets,
};
let mode = sequence.iter().next().map(|item| {
if item.downcast::<JsString>().is_ok() {
Ok(Mode::NoOffsets)
} else if item.downcast::<JsArray>().is_ok() {
Ok(Mode::Offsets)
} else {
Err("Input must be (String | [String, [number, number]])[]")
}
})
.unwrap()
.map_err(|e| cx.throw_error::<_, ()>(e.to_string()).unwrap_err())?;
let mut total_len = 0;
let sequence = sequence.iter().map(|item| match mode {
Mode::NoOffsets => {
let s = item.downcast::<JsString>().or_throw(&mut cx)?.value();
let len = s.chars().count();
total_len += len;
Ok((s, (total_len - len, total_len)))
},
Mode::Offsets => {
let tuple = item.downcast::<JsArray>().or_throw(&mut cx)?;
let s = tuple.get(&mut cx, 0)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value();
let offsets = tuple.get(&mut cx, 1)?
.downcast::<JsArray>()
.or_throw(&mut cx)?;
let (start, end) = (
offsets.get(&mut cx, 0)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize,
offsets.get(&mut cx, 1)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize,
);
Ok((s, (start, end)))
}
}).collect::<Result<Vec<_>, _>>()?;
let encoding = {
let this = cx.this();
let guard = cx.lock();
let res = this.borrow(&guard)
.model
.execute(|model| model.unwrap()
.tokenize(sequence)
.map(|tokens| tk::tokenizer::Encoding::from_tokens(tokens, type_id))
);
res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
};
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
let guard = cx.lock();
js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));
Ok(js_encoding.upcast())
}
}
}
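
The hunk above drops the synchronous encode method from the Model binding; per its doc comment it took a pre-tokenized sequence (plain strings or [string, [start, end]] tuples) plus an optional typeId and returned an Encoding directly. A minimal sketch of that old calling convention, reconstructed from the doc comment; how model is obtained and the variable names are illustrative assumptions:

// Hypothetical usage of the removed synchronous Model.encode (names are illustrative).
// `model` is assumed to come from one of the *FromFiles helpers in these bindings.
declare const model: {
  encode(sequence: (string | [string, [number, number]])[], typeId?: number): unknown;
};

// Plain strings: offsets are derived cumulatively from the string lengths.
const enc1 = model.encode(["Hello", "world"]);

// Explicit [string, [start, end]] tuples: offsets are passed through as-is.
const enc2 = model.encode([["Hello", [0, 5]], ["world", [6, 11]]], 1);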

View File

@@ -2,6 +2,7 @@ extern crate tokenizers as tk;
use crate::encoding::*;
use neon::prelude::*;
use rayon::prelude::*;
use tk::tokenizer::{EncodeInput, Encoding, Tokenizer};
pub struct WorkingTokenizer {
@@ -100,6 +101,97 @@ impl Task for EncodeTask {
}
}
pub enum EncodeTokenizedTask {
Single(WorkingTokenizer, Option<Vec<(String, (usize, usize))>>, u32),
Batch(
WorkingTokenizer,
Option<Vec<Vec<(String, (usize, usize))>>>,
u32,
),
}
pub enum EncodeTokenizedOutput {
Single(Encoding),
Batch(Vec<Encoding>),
}
impl Task for EncodeTokenizedTask {
type Output = EncodeTokenizedOutput;
type Error = String;
type JsEvent = JsValue;
fn perform(&self) -> Result<Self::Output, Self::Error> {
match self {
EncodeTokenizedTask::Single(worker, input, type_id) => {
let input = unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
tokenizer
.get_model()
.tokenize(input.unwrap())
.map_err(|e| format!("{}", e))
.map(|tokens| {
EncodeTokenizedOutput::Single(Encoding::from_tokens(tokens, *type_id))
})
}
EncodeTokenizedTask::Batch(worker, input, type_id) => {
let input: Option<Vec<_>> =
unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
input
.unwrap()
.into_par_iter()
.map(|input| {
tokenizer
.get_model()
.tokenize(input)
.map_err(|e| format!("{}", e))
.map(|tokens| Encoding::from_tokens(tokens, *type_id))
})
.collect::<Result<_, _>>()
.map(EncodeTokenizedOutput::Batch)
}
}
}
fn complete(
self,
mut cx: TaskContext,
result: Result<Self::Output, Self::Error>,
) -> JsResult<Self::JsEvent> {
match result.map_err(|e| cx.throw_error::<_, ()>(e).unwrap_err())? {
EncodeTokenizedOutput::Single(encoding) => {
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
// Set the actual encoding
let guard = cx.lock();
js_encoding
.borrow_mut(&guard)
.encoding
.to_owned(Box::new(encoding));
Ok(js_encoding.upcast())
}
EncodeTokenizedOutput::Batch(encodings) => {
let result = JsArray::new(&mut cx, encodings.len() as u32);
for (i, encoding) in encodings.into_iter().enumerate() {
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
// Set the actual encoding
let guard = cx.lock();
js_encoding
.borrow_mut(&guard)
.encoding
.to_owned(Box::new(encoding));
result.set(&mut cx, i as u32, js_encoding)?;
}
Ok(result.upcast())
}
}
}
}
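
Because the encoding runs as a Neon Task, its result reaches JavaScript through a Node-style callback: the Single variant yields one Encoding, the Batch variant an array of Encodings. A sketch of the corresponding callback shapes on the JS side; the TypeScript typings are assumptions, the bindings themselves only guarantee the (err, result) convention:

// Assumed callback shapes matching EncodeTokenizedOutput::Single and ::Batch.
type Encoding = unknown; // stand-in for the JsEncoding wrapper exposed to JS

type SingleCallback = (err: Error | null, encoding: Encoding) => void;
type BatchCallback = (err: Error | null, encodings: Encoding[]) => void;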
pub enum DecodeTask {
Single(WorkingTokenizer, Vec<u32>, bool),
Batch(WorkingTokenizer, Vec<Vec<u32>>, bool),

View File

@@ -6,7 +6,7 @@ use crate::models::JsModel;
use crate::normalizers::JsNormalizer;
use crate::pre_tokenizers::JsPreTokenizer;
use crate::processors::JsPostProcessor;
use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer};
use crate::tasks::tokenizer::{DecodeTask, EncodeTask, EncodeTokenizedTask, WorkingTokenizer};
use crate::trainers::JsTrainer;
use neon::prelude::*;
@@ -237,6 +237,170 @@ declare_types! {
Ok(cx.undefined().upcast())
}
method encodeTokenized(mut cx) {
/// encodeTokenized(
/// sequence: (String | [String, [number, number]])[],
/// typeId?: number = 0,
/// callback: (err, Encoding)
/// )
let sequence = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
let type_id = cx.argument_opt(1)
.map_or(Some(0), |arg| arg.downcast::<JsNumber>()
.ok()
.map(|h| h.value() as u32)
).unwrap();
enum Mode {
NoOffsets,
Offsets,
};
let mode = sequence.iter().next().map(|item| {
if item.downcast::<JsString>().is_ok() {
Ok(Mode::NoOffsets)
} else if item.downcast::<JsArray>().is_ok() {
Ok(Mode::Offsets)
} else {
Err("Input must be (String | [String, [number, number]])[]")
}
})
.unwrap()
.map_err(|e| cx.throw_error::<_, ()>(e.to_string()).unwrap_err())?;
let mut total_len = 0;
let sequence = sequence.iter().map(|item| match mode {
Mode::NoOffsets => {
let s = item.downcast::<JsString>().or_throw(&mut cx)?.value();
let len = s.chars().count();
total_len += len;
Ok((s, (total_len - len, total_len)))
},
Mode::Offsets => {
let tuple = item.downcast::<JsArray>().or_throw(&mut cx)?;
let s = tuple.get(&mut cx, 0)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value();
let offsets = tuple.get(&mut cx, 1)?
.downcast::<JsArray>()
.or_throw(&mut cx)?;
let (start, end) = (
offsets.get(&mut cx, 0)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize,
offsets.get(&mut cx, 1)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize,
);
Ok((s, (start, end)))
}
}).collect::<Result<Vec<_>, _>>()?;
let callback = cx.argument::<JsFunction>(2)?;
let worker = {
let this = cx.this();
let guard = cx.lock();
let worker = this.borrow(&guard).prepare_for_task();
worker
};
let task = EncodeTokenizedTask::Single(worker, Some(sequence), type_id);
task.schedule(callback);
Ok(cx.undefined().upcast())
}
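
A usage sketch for encodeTokenized as declared above, following its doc comment; the tokenizer value and the Promise wrapper are illustrative assumptions, not part of this commit:

// Hypothetical call site; `tokenizer` is assumed to be the native Tokenizer object.
declare const tokenizer: any;

tokenizer.encodeTokenized(["Hello", "world"], 0, (err: Error | null, encoding: unknown) => {
  if (err) throw err;
  // `encoding` is a single Encoding built from the pre-tokenized input
});

// Optional: wrap the callback in a Promise for async/await style callers.
const encodeTokenized = (seq: (string | [string, [number, number]])[], typeId = 0) =>
  new Promise((resolve, reject) =>
    tokenizer.encodeTokenized(seq, typeId, (e: Error | null, enc: unknown) =>
      e ? reject(e) : resolve(enc)
    )
  );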
method encodeTokenizedBatch(mut cx) {
/// encodeTokenizedBatch(
/// sequences: (String | [String, [number, number]])[][],
/// typeId?: number = 0,
/// callback: (err, Encoding[])
/// )
let sequences = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
let type_id = cx.argument_opt(1)
.map_or(Some(0), |arg| arg.downcast::<JsNumber>()
.ok()
.map(|h| h.value() as u32)
).unwrap();
enum Mode {
NoOffsets,
Offsets,
};
let mode = sequences.iter().next().map(|sequence| {
if let Ok(sequence) = sequence.downcast::<JsArray>().or_throw(&mut cx) {
sequence.to_vec(&mut cx).ok().map(|s| s.iter().next().map(|item| {
if item.downcast::<JsString>().is_ok() {
Some(Mode::NoOffsets)
} else if item.downcast::<JsArray>().is_ok() {
Some(Mode::Offsets)
} else {
None
}
}).flatten()).flatten()
} else {
None
}
})
.flatten()
.ok_or_else(||
cx.throw_error::<_, ()>(
"Input must be (String | [String, [number, number]])[]"
).unwrap_err()
)?;
let sequences = sequences.into_iter().map(|sequence| {
let mut total_len = 0;
sequence.downcast::<JsArray>().or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|item| match mode {
Mode::NoOffsets => {
let s = item.downcast::<JsString>().or_throw(&mut cx)?.value();
let len = s.chars().count();
total_len += len;
Ok((s, (total_len - len, total_len)))
},
Mode::Offsets => {
let tuple = item.downcast::<JsArray>().or_throw(&mut cx)?;
let s = tuple.get(&mut cx, 0)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value();
let offsets = tuple.get(&mut cx, 1)?
.downcast::<JsArray>()
.or_throw(&mut cx)?;
let (start, end) = (
offsets.get(&mut cx, 0)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize,
offsets.get(&mut cx, 1)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize,
);
Ok((s, (start, end)))
}
}).collect::<Result<Vec<_>, _>>()
})
.collect::<Result<Vec<_>, _>>()?;
let callback = cx.argument::<JsFunction>(2)?;
let worker = {
let this = cx.this();
let guard = cx.lock();
let worker = this.borrow(&guard).prepare_for_task();
worker
};
let task = EncodeTokenizedTask::Batch(worker, Some(sequences), type_id);
task.schedule(callback);
Ok(cx.undefined().upcast())
}
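
A matching sketch for the batch variant: each inner sequence uses the same shape as encodeTokenized, and the callback receives one Encoding per input sequence (the Rust side encodes the batch in parallel with rayon). The tokenizer value and typings are assumptions:

declare const tokenizer: any; // assumed native Tokenizer object

// Two pre-tokenized sequences with explicit [start, end] offsets.
const batch: [string, [number, number]][][] = [
  [["Hello", [0, 5]], ["world", [6, 11]]],
  [["How", [0, 3]], ["are", [4, 7]], ["you", [8, 11]]],
];

tokenizer.encodeTokenizedBatch(batch, 0, (err: Error | null, encodings: unknown[]) => {
  if (err) throw err;
  // `encodings` holds one Encoding per input sequence, in the same order
});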
method decode(mut cx) {
// decode(ids: number[], skipSpecialTokens: bool, callback)