mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Node - Add encode & encodeBatch
This commit is contained in:
20
bindings/node/native/src/encoding.rs
Normal file
20
bindings/node/native/src/encoding.rs
Normal file
@ -0,0 +1,20 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use crate::utils::Container;
|
||||
use neon::prelude::*;
|
||||
|
||||
/// Encoding
|
||||
pub struct Encoding {
|
||||
pub encoding: Container<tk::tokenizer::Encoding>,
|
||||
}
|
||||
|
||||
declare_types! {
|
||||
pub class JsEncoding for Encoding {
|
||||
init(_) {
|
||||
// This should never be called from JavaScript
|
||||
Ok(Encoding {
|
||||
encoding: Container::Empty
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
@ -2,6 +2,7 @@ extern crate neon;
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
mod decoders;
|
||||
mod encoding;
|
||||
mod models;
|
||||
mod processors;
|
||||
mod tokenizer;
|
||||
|
@ -30,7 +30,7 @@ fn bert_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
|
||||
sep.get(&mut cx, 0)?
|
||||
.downcast::<JsString>()
|
||||
.or_throw(&mut cx)?
|
||||
.value() as String,
|
||||
.value(),
|
||||
sep.get(&mut cx, 1)?
|
||||
.downcast::<JsNumber>()
|
||||
.or_throw(&mut cx)?
|
||||
@ -40,7 +40,7 @@ fn bert_processing(mut cx: FunctionContext) -> JsResult<JsPostProcessor> {
|
||||
cls.get(&mut cx, 0)?
|
||||
.downcast::<JsString>()
|
||||
.or_throw(&mut cx)?
|
||||
.value() as String,
|
||||
.value(),
|
||||
cls.get(&mut cx, 1)?
|
||||
.downcast::<JsNumber>()
|
||||
.or_throw(&mut cx)?
|
||||
|
@ -1,5 +1,6 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use crate::encoding::*;
|
||||
use crate::models::*;
|
||||
use neon::prelude::*;
|
||||
|
||||
@ -10,8 +11,8 @@ pub struct Tokenizer {
|
||||
|
||||
declare_types! {
|
||||
pub class JsTokenizer for Tokenizer {
|
||||
// Create
|
||||
init(mut cx) {
|
||||
// init(model: JsModel)
|
||||
let mut model = cx.argument::<JsModel>(0)?;
|
||||
if let Some(instance) = {
|
||||
let guard = cx.lock();
|
||||
@ -26,7 +27,22 @@ declare_types! {
|
||||
}
|
||||
}
|
||||
|
||||
method with_model(mut cx) {
|
||||
method getVocabSize(mut cx) {
|
||||
// getVocabSize(withAddedTokens: bool = true)
|
||||
let mut with_added_tokens = true;
|
||||
if let Some(args) = cx.argument_opt(0) {
|
||||
with_added_tokens = args.downcast::<JsBoolean>().or_throw(&mut cx)?.value() as bool;
|
||||
}
|
||||
|
||||
let mut this = cx.this();
|
||||
let guard = cx.lock();
|
||||
let size = this.borrow_mut(&guard).tokenizer.get_vocab_size(with_added_tokens);
|
||||
|
||||
Ok(cx.number(size as f64).upcast())
|
||||
}
|
||||
|
||||
method withModel(mut cx) {
|
||||
// with_model(model: JsModel)
|
||||
let mut model = cx.argument::<JsModel>(0)?;
|
||||
if let Some(instance) = {
|
||||
let guard = cx.lock();
|
||||
@ -45,6 +61,83 @@ declare_types! {
|
||||
cx.throw_error("The Model is already being used in another Tokenizer")
|
||||
}
|
||||
}
|
||||
|
||||
method encode(mut cx) {
|
||||
// encode(sentence: String, pair?: String): Encoding
|
||||
let sentence = cx.argument::<JsString>(0)?.value();
|
||||
let mut pair: Option<String> = None;
|
||||
if let Some(args) = cx.argument_opt(1) {
|
||||
pair = Some(args.downcast::<JsString>().or_throw(&mut cx)?.value());
|
||||
}
|
||||
|
||||
let input = if let Some(pair) = pair {
|
||||
tk::tokenizer::EncodeInput::Dual(sentence, pair)
|
||||
} else {
|
||||
tk::tokenizer::EncodeInput::Single(sentence)
|
||||
};
|
||||
|
||||
let encoding = {
|
||||
let this = cx.this();
|
||||
let guard = cx.lock();
|
||||
let res = this.borrow(&guard).tokenizer.encode(input);
|
||||
res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
|
||||
};
|
||||
|
||||
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
|
||||
// Set the actual encoding
|
||||
let guard = cx.lock();
|
||||
js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));
|
||||
|
||||
Ok(js_encoding.upcast())
|
||||
}
|
||||
|
||||
method encodeBatch(mut cx) {
|
||||
// type EncodeInput = (String | [String, String])[]
|
||||
// encode_batch(sentences: EncodeInput[]): Encoding[]
|
||||
let inputs = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
|
||||
let inputs = inputs.into_iter().map(|value| {
|
||||
if let Ok(s) = value.downcast::<JsString>() {
|
||||
Ok(tk::tokenizer::EncodeInput::Single(s.value()))
|
||||
} else if let Ok(arr) = value.downcast::<JsArray>() {
|
||||
if arr.len() != 2 {
|
||||
cx.throw_error("Input must be an array of `String | [String, String]`")
|
||||
} else {
|
||||
Ok(tk::tokenizer::EncodeInput::Dual(
|
||||
arr.get(&mut cx, 0)?
|
||||
.downcast::<JsString>()
|
||||
.or_throw(&mut cx)?
|
||||
.value(),
|
||||
arr.get(&mut cx, 1)?
|
||||
.downcast::<JsString>()
|
||||
.or_throw(&mut cx)?
|
||||
.value())
|
||||
)
|
||||
}
|
||||
} else {
|
||||
cx.throw_error("Input must be an array of `String | [String, String]`")
|
||||
}
|
||||
}).collect::<NeonResult<Vec<_>>>()?;
|
||||
|
||||
let encodings = {
|
||||
let this = cx.this();
|
||||
let guard = cx.lock();
|
||||
let res = this.borrow(&guard).tokenizer.encode_batch(inputs);
|
||||
res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
|
||||
};
|
||||
|
||||
let result = JsArray::new(&mut cx, encodings.len() as u32);
|
||||
for (i, encoding) in encodings.into_iter().enumerate() {
|
||||
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
|
||||
|
||||
// Set the actual encoding
|
||||
let guard = cx.lock();
|
||||
js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));
|
||||
|
||||
result.set(&mut cx, i as u32, js_encoding)?;
|
||||
}
|
||||
|
||||
Ok(result.upcast())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user