mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Node - Add Model.encodeBatch and make everything async
This commit is contained in:
1
bindings/node/native/Cargo.lock
generated
1
bindings/node/native/Cargo.lock
generated
@ -251,6 +251,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"neon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"neon-build 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"tokenizers 0.8.0",
|
||||
]
|
||||
|
||||
|
@ -15,4 +15,5 @@ neon-build = "0.3.3"
|
||||
|
||||
[dependencies]
|
||||
neon = "0.3.3"
|
||||
rayon = "1.2.0"
|
||||
tokenizers = { path = "../../../tokenizers" }
|
||||
|
@ -1,6 +1,7 @@
|
||||
#![warn(clippy::all)]
|
||||
|
||||
extern crate neon;
|
||||
extern crate rayon;
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
mod container;
|
||||
|
@ -1,7 +1,6 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use crate::container::Container;
|
||||
use crate::encoding::JsEncoding;
|
||||
use crate::tasks::models::{BPEFromFilesTask, WordPieceFromFilesTask};
|
||||
use neon::prelude::*;
|
||||
use std::path::Path;
|
||||
@ -52,83 +51,6 @@ declare_types! {
|
||||
Err(e) => cx.throw_error(format!("{}", e))
|
||||
}
|
||||
}
|
||||
|
||||
method encode(mut cx) {
    // encode(sequence: (String | [String, [number, number]])[], typeId?: number = 0):
    //   Encoding
    //
    // Synchronously tokenizes an already-split sequence with this model and
    // wraps the resulting tokens in a new `Encoding` object.
    let sequence = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;

    // `typeId` is optional: a missing, `undefined` or `null` argument means 0.
    // Any other non-number raises a JS error instead of panicking.
    let type_id = match cx.argument_opt(1) {
        None => 0,
        Some(arg) => match arg.downcast::<JsNumber>() {
            Ok(number) => number.value() as u32,
            Err(_) if arg.is_a::<JsUndefined>() || arg.is_a::<JsNull>() => 0,
            Err(_) => return cx.throw_error("typeId must be a number"),
        },
    };

    #[derive(Clone, Copy)]
    enum Mode {
        NoOffsets,
        Offsets,
    }

    // The first item decides how every item is read. An empty sequence has
    // nothing to read, so the mode is irrelevant and we must not panic on it.
    let mode = match sequence.first() {
        None => Mode::NoOffsets,
        Some(item) if item.is_a::<JsString>() => Mode::NoOffsets,
        Some(item) if item.is_a::<JsArray>() => Mode::Offsets,
        Some(_) => {
            return cx.throw_error("Input must be (String | [String, [number, number]])[]")
        }
    };

    // Running length (in chars) used to synthesize offsets in `NoOffsets` mode.
    let mut total_len = 0;
    let sequence = sequence
        .iter()
        .map(|item| match mode {
            Mode::NoOffsets => {
                let text = item.downcast::<JsString>().or_throw(&mut cx)?.value();
                let start = total_len;
                total_len += text.chars().count();
                Ok((text, (start, total_len)))
            }
            Mode::Offsets => {
                let pair = item.downcast::<JsArray>().or_throw(&mut cx)?;
                let text = pair
                    .get(&mut cx, 0)?
                    .downcast::<JsString>()
                    .or_throw(&mut cx)?
                    .value();
                let span = pair
                    .get(&mut cx, 1)?
                    .downcast::<JsArray>()
                    .or_throw(&mut cx)?;
                let start = span
                    .get(&mut cx, 0)?
                    .downcast::<JsNumber>()
                    .or_throw(&mut cx)?
                    .value() as usize;
                let end = span
                    .get(&mut cx, 1)?
                    .downcast::<JsNumber>()
                    .or_throw(&mut cx)?
                    .value() as usize;
                Ok((text, (start, end)))
            }
        })
        .collect::<NeonResult<Vec<_>>>()?;

    let encoding = {
        let this = cx.this();
        let guard = cx.lock();
        let res = this.borrow(&guard)
            .model
            .execute(|model| model.unwrap()
                .tokenize(sequence)
                .map(|tokens| tk::tokenizer::Encoding::from_tokens(tokens, type_id))
            );
        res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
    };

    // Create an empty JS `Encoding`, then move the native encoding into it.
    let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;

    let guard = cx.lock();
    js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));

    Ok(js_encoding.upcast())
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ extern crate tokenizers as tk;
|
||||
|
||||
use crate::encoding::*;
|
||||
use neon::prelude::*;
|
||||
use rayon::prelude::*;
|
||||
use tk::tokenizer::{EncodeInput, Encoding, Tokenizer};
|
||||
|
||||
pub struct WorkingTokenizer {
|
||||
@ -100,6 +101,97 @@ impl Task for EncodeTask {
|
||||
}
|
||||
}
|
||||
|
||||
pub enum EncodeTokenizedTask {
|
||||
Single(WorkingTokenizer, Option<Vec<(String, (usize, usize))>>, u32),
|
||||
Batch(
|
||||
WorkingTokenizer,
|
||||
Option<Vec<Vec<(String, (usize, usize))>>>,
|
||||
u32,
|
||||
),
|
||||
}
|
||||
|
||||
pub enum EncodeTokenizedOutput {
|
||||
Single(Encoding),
|
||||
Batch(Vec<Encoding>),
|
||||
}
|
||||
|
||||
impl Task for EncodeTokenizedTask {
|
||||
type Output = EncodeTokenizedOutput;
|
||||
type Error = String;
|
||||
type JsEvent = JsValue;
|
||||
|
||||
fn perform(&self) -> Result<Self::Output, Self::Error> {
|
||||
match self {
|
||||
EncodeTokenizedTask::Single(worker, input, type_id) => {
|
||||
let input = unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
|
||||
let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
|
||||
|
||||
tokenizer
|
||||
.get_model()
|
||||
.tokenize(input.unwrap())
|
||||
.map_err(|e| format!("{}", e))
|
||||
.map(|tokens| {
|
||||
EncodeTokenizedOutput::Single(Encoding::from_tokens(tokens, *type_id))
|
||||
})
|
||||
}
|
||||
EncodeTokenizedTask::Batch(worker, input, type_id) => {
|
||||
let input: Option<Vec<_>> =
|
||||
unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
|
||||
let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
|
||||
|
||||
input
|
||||
.unwrap()
|
||||
.into_par_iter()
|
||||
.map(|input| {
|
||||
tokenizer
|
||||
.get_model()
|
||||
.tokenize(input)
|
||||
.map_err(|e| format!("{}", e))
|
||||
.map(|tokens| Encoding::from_tokens(tokens, *type_id))
|
||||
})
|
||||
.collect::<Result<_, _>>()
|
||||
.map(EncodeTokenizedOutput::Batch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn complete(
|
||||
self,
|
||||
mut cx: TaskContext,
|
||||
result: Result<Self::Output, Self::Error>,
|
||||
) -> JsResult<Self::JsEvent> {
|
||||
match result.map_err(|e| cx.throw_error::<_, ()>(e).unwrap_err())? {
|
||||
EncodeTokenizedOutput::Single(encoding) => {
|
||||
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
|
||||
// Set the actual encoding
|
||||
let guard = cx.lock();
|
||||
js_encoding
|
||||
.borrow_mut(&guard)
|
||||
.encoding
|
||||
.to_owned(Box::new(encoding));
|
||||
|
||||
Ok(js_encoding.upcast())
|
||||
}
|
||||
EncodeTokenizedOutput::Batch(encodings) => {
|
||||
let result = JsArray::new(&mut cx, encodings.len() as u32);
|
||||
for (i, encoding) in encodings.into_iter().enumerate() {
|
||||
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
|
||||
|
||||
// Set the actual encoding
|
||||
let guard = cx.lock();
|
||||
js_encoding
|
||||
.borrow_mut(&guard)
|
||||
.encoding
|
||||
.to_owned(Box::new(encoding));
|
||||
|
||||
result.set(&mut cx, i as u32, js_encoding)?;
|
||||
}
|
||||
Ok(result.upcast())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum DecodeTask {
|
||||
Single(WorkingTokenizer, Vec<u32>, bool),
|
||||
Batch(WorkingTokenizer, Vec<Vec<u32>>, bool),
|
||||
|
@ -6,7 +6,7 @@ use crate::models::JsModel;
|
||||
use crate::normalizers::JsNormalizer;
|
||||
use crate::pre_tokenizers::JsPreTokenizer;
|
||||
use crate::processors::JsPostProcessor;
|
||||
use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer};
|
||||
use crate::tasks::tokenizer::{DecodeTask, EncodeTask, EncodeTokenizedTask, WorkingTokenizer};
|
||||
use crate::trainers::JsTrainer;
|
||||
use neon::prelude::*;
|
||||
|
||||
@ -237,6 +237,170 @@ declare_types! {
|
||||
Ok(cx.undefined().upcast())
|
||||
}
|
||||
|
||||
method encodeTokenized(mut cx) {
    // encodeTokenized(
    //   sequence: (String | [String, [number, number]])[],
    //   typeId?: number = 0,
    //   callback: (err, Encoding)
    // )
    //
    // Parses the arguments on the JS thread, then schedules an
    // `EncodeTokenizedTask::Single`; the result is delivered to `callback`.

    let sequence = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;

    // `typeId` is optional: a missing, `undefined` or `null` argument means 0.
    // Any other non-number raises a JS error instead of panicking.
    let type_id = match cx.argument_opt(1) {
        None => 0,
        Some(arg) => match arg.downcast::<JsNumber>() {
            Ok(number) => number.value() as u32,
            Err(_) if arg.is_a::<JsUndefined>() || arg.is_a::<JsNull>() => 0,
            Err(_) => return cx.throw_error("typeId must be a number"),
        },
    };

    #[derive(Clone, Copy)]
    enum Mode {
        NoOffsets,
        Offsets,
    }

    // The first item decides how every item is read. An empty sequence has
    // nothing to read, so the mode is irrelevant and we must not panic on it.
    let mode = match sequence.first() {
        None => Mode::NoOffsets,
        Some(item) if item.is_a::<JsString>() => Mode::NoOffsets,
        Some(item) if item.is_a::<JsArray>() => Mode::Offsets,
        Some(_) => {
            return cx.throw_error("Input must be (String | [String, [number, number]])[]")
        }
    };

    // Running length (in chars) used to synthesize offsets in `NoOffsets` mode.
    let mut total_len = 0;
    let sequence = sequence
        .iter()
        .map(|item| match mode {
            Mode::NoOffsets => {
                let text = item.downcast::<JsString>().or_throw(&mut cx)?.value();
                let start = total_len;
                total_len += text.chars().count();
                Ok((text, (start, total_len)))
            }
            Mode::Offsets => {
                let pair = item.downcast::<JsArray>().or_throw(&mut cx)?;
                let text = pair
                    .get(&mut cx, 0)?
                    .downcast::<JsString>()
                    .or_throw(&mut cx)?
                    .value();
                let span = pair
                    .get(&mut cx, 1)?
                    .downcast::<JsArray>()
                    .or_throw(&mut cx)?;
                let start = span
                    .get(&mut cx, 0)?
                    .downcast::<JsNumber>()
                    .or_throw(&mut cx)?
                    .value() as usize;
                let end = span
                    .get(&mut cx, 1)?
                    .downcast::<JsNumber>()
                    .or_throw(&mut cx)?
                    .value() as usize;
                Ok((text, (start, end)))
            }
        })
        .collect::<NeonResult<Vec<_>>>()?;
    let callback = cx.argument::<JsFunction>(2)?;

    let worker = {
        let this = cx.this();
        let guard = cx.lock();
        // Bound to a local first so the borrow of `guard` ends before the
        // block's value is produced.
        let worker = this.borrow(&guard).prepare_for_task();
        worker
    };

    let task = EncodeTokenizedTask::Single(worker, Some(sequence), type_id);
    task.schedule(callback);
    Ok(cx.undefined().upcast())
}
|
||||
|
||||
method encodeTokenizedBatch(mut cx) {
    // encodeTokenizedBatch(
    //   sequences: (String | [String, [number, number]])[][],
    //   typeId?: number = 0,
    //   callback: (err, Encoding[])
    // )
    //
    // Parses the arguments on the JS thread, then schedules an
    // `EncodeTokenizedTask::Batch`; the result is delivered to `callback`.

    let sequences = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;

    // `typeId` is optional: a missing, `undefined` or `null` argument means 0.
    // Any other non-number raises a JS error instead of panicking.
    let type_id = match cx.argument_opt(1) {
        None => 0,
        Some(arg) => match arg.downcast::<JsNumber>() {
            Ok(number) => number.value() as u32,
            Err(_) if arg.is_a::<JsUndefined>() || arg.is_a::<JsNull>() => 0,
            Err(_) => return cx.throw_error("typeId must be a number"),
        },
    };

    #[derive(Clone, Copy)]
    enum Mode {
        NoOffsets,
        Offsets,
    }

    // Unpack every inner array up front. A non-array entry throws exactly
    // once here, instead of being thrown and then reported a second time
    // during mode detection.
    let mut rows = Vec::with_capacity(sequences.len());
    for sequence in sequences {
        let row = sequence
            .downcast::<JsArray>()
            .or_throw(&mut cx)?
            .to_vec(&mut cx)?;
        rows.push(row);
    }

    // The first available item decides how every item is read. With no items
    // at all (empty batch or only empty sequences) the mode is never used.
    let mode = match rows.iter().flatten().next() {
        None => Mode::NoOffsets,
        Some(item) if item.is_a::<JsString>() => Mode::NoOffsets,
        Some(item) if item.is_a::<JsArray>() => Mode::Offsets,
        Some(_) => {
            return cx.throw_error("Input must be (String | [String, [number, number]])[]")
        }
    };

    let sequences = rows
        .into_iter()
        .map(|row| {
            // Offsets restart at 0 for every sequence of the batch.
            let mut total_len = 0;
            row.into_iter()
                .map(|item| match mode {
                    Mode::NoOffsets => {
                        let text = item.downcast::<JsString>().or_throw(&mut cx)?.value();
                        let start = total_len;
                        total_len += text.chars().count();
                        Ok((text, (start, total_len)))
                    }
                    Mode::Offsets => {
                        let pair = item.downcast::<JsArray>().or_throw(&mut cx)?;
                        let text = pair
                            .get(&mut cx, 0)?
                            .downcast::<JsString>()
                            .or_throw(&mut cx)?
                            .value();
                        let span = pair
                            .get(&mut cx, 1)?
                            .downcast::<JsArray>()
                            .or_throw(&mut cx)?;
                        let start = span
                            .get(&mut cx, 0)?
                            .downcast::<JsNumber>()
                            .or_throw(&mut cx)?
                            .value() as usize;
                        let end = span
                            .get(&mut cx, 1)?
                            .downcast::<JsNumber>()
                            .or_throw(&mut cx)?
                            .value() as usize;
                        Ok((text, (start, end)))
                    }
                })
                .collect::<NeonResult<Vec<_>>>()
        })
        .collect::<NeonResult<Vec<_>>>()?;
    let callback = cx.argument::<JsFunction>(2)?;

    let worker = {
        let this = cx.this();
        let guard = cx.lock();
        // Bound to a local first so the borrow of `guard` ends before the
        // block's value is produced.
        let worker = this.borrow(&guard).prepare_for_task();
        worker
    };

    let task = EncodeTokenizedTask::Batch(worker, Some(sequences), type_id);
    task.schedule(callback);
    Ok(cx.undefined().upcast())
}
|
||||
|
||||
method decode(mut cx) {
|
||||
// decode(ids: number[], skipSpecialTokens: bool, callback)
|
||||
|
||||
|
Reference in New Issue
Block a user