Node - Add encode & encodeBatch

This commit is contained in:
Anthony MOI
2020-01-09 11:47:35 -05:00
parent 1a54692190
commit 19d41a5810
4 changed files with 118 additions and 4 deletions

View File

@ -0,0 +1,20 @@
extern crate tokenizers as tk;
use crate::utils::Container;
use neon::prelude::*;
/// Encoding
pub struct Encoding {
pub encoding: Container<tk::tokenizer::Encoding>,
}
declare_types! {
pub class JsEncoding for Encoding {
init(_) {
// This should never be called from JavaScript
Ok(Encoding {
encoding: Container::Empty
})
}
}
}

View File

@ -2,6 +2,7 @@ extern crate neon;
extern crate tokenizers as tk;
mod decoders;
mod encoding;
mod models;
mod processors;
mod tokenizer;

View File

@ -30,7 +30,7 @@ fn bert_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
sep.get(&mut cx, 0)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value() as String,
.value(),
sep.get(&mut cx, 1)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?
@ -40,7 +40,7 @@ fn bert_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
cls.get(&mut cx, 0)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value() as String,
.value(),
cls.get(&mut cx, 1)?
.downcast::<JsNumber>()
.or_throw(&mut cx)?

View File

@ -1,5 +1,6 @@
extern crate tokenizers as tk;
use crate::encoding::*;
use crate::models::*;
use neon::prelude::*;
@ -10,8 +11,8 @@ pub struct Tokenizer {
declare_types! {
pub class JsTokenizer for Tokenizer {
// Create
init(mut cx) {
// init(model: JsModel)
let mut model = cx.argument::<JsModel>(0)?;
if let Some(instance) = {
let guard = cx.lock();
@ -26,7 +27,22 @@ declare_types! {
}
}
method with_model(mut cx) {
method getVocabSize(mut cx) {
// getVocabSize(withAddedTokens: bool = true)
let mut with_added_tokens = true;
if let Some(args) = cx.argument_opt(0) {
with_added_tokens = args.downcast::<JsBoolean>().or_throw(&mut cx)?.value() as bool;
}
let mut this = cx.this();
let guard = cx.lock();
let size = this.borrow_mut(&guard).tokenizer.get_vocab_size(with_added_tokens);
Ok(cx.number(size as f64).upcast())
}
method withModel(mut cx) {
// with_model(model: JsModel)
let mut model = cx.argument::<JsModel>(0)?;
if let Some(instance) = {
let guard = cx.lock();
@ -45,6 +61,83 @@ declare_types! {
cx.throw_error("The Model is already being used in another Tokenizer")
}
}
method encode(mut cx) {
// encode(sentence: String, pair?: String): Encoding
let sentence = cx.argument::<JsString>(0)?.value();
let mut pair: Option<String> = None;
if let Some(args) = cx.argument_opt(1) {
pair = Some(args.downcast::<JsString>().or_throw(&mut cx)?.value());
}
let input = if let Some(pair) = pair {
tk::tokenizer::EncodeInput::Dual(sentence, pair)
} else {
tk::tokenizer::EncodeInput::Single(sentence)
};
let encoding = {
let this = cx.this();
let guard = cx.lock();
let res = this.borrow(&guard).tokenizer.encode(input);
res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
};
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
// Set the actual encoding
let guard = cx.lock();
js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));
Ok(js_encoding.upcast())
}
method encodeBatch(mut cx) {
// type EncodeInput = (String | [String, String])[]
// encode_batch(sentences: EncodeInput[]): Encoding[]
let inputs = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
let inputs = inputs.into_iter().map(|value| {
if let Ok(s) = value.downcast::<JsString>() {
Ok(tk::tokenizer::EncodeInput::Single(s.value()))
} else if let Ok(arr) = value.downcast::<JsArray>() {
if arr.len() != 2 {
cx.throw_error("Input must be an array of `String | [String, String]`")
} else {
Ok(tk::tokenizer::EncodeInput::Dual(
arr.get(&mut cx, 0)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value(),
arr.get(&mut cx, 1)?
.downcast::<JsString>()
.or_throw(&mut cx)?
.value())
)
}
} else {
cx.throw_error("Input must be an array of `String | [String, String]`")
}
}).collect::<NeonResult<Vec<_>>>()?;
let encodings = {
let this = cx.this();
let guard = cx.lock();
let res = this.borrow(&guard).tokenizer.encode_batch(inputs);
res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
};
let result = JsArray::new(&mut cx, encodings.len() as u32);
for (i, encoding) in encodings.into_iter().enumerate() {
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
// Set the actual encoding
let guard = cx.lock();
js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));
result.set(&mut cx, i as u32, js_encoding)?;
}
Ok(result.upcast())
}
}
}