diff --git a/bindings/node/native/Cargo.lock b/bindings/node/native/Cargo.lock
index d2d5090e..39dff459 100644
--- a/bindings/node/native/Cargo.lock
+++ b/bindings/node/native/Cargo.lock
@@ -251,6 +251,7 @@ version = "0.1.0"
 dependencies = [
  "neon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
  "neon-build 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "tokenizers 0.8.0",
 ]

diff --git a/bindings/node/native/Cargo.toml b/bindings/node/native/Cargo.toml
index b631ee2f..2e648d6d 100644
--- a/bindings/node/native/Cargo.toml
+++ b/bindings/node/native/Cargo.toml
@@ -15,4 +15,5 @@ neon-build = "0.3.3"
 
 [dependencies]
 neon = "0.3.3"
+rayon = "1.2.0"
 tokenizers = { path = "../../../tokenizers" }
diff --git a/bindings/node/native/src/lib.rs b/bindings/node/native/src/lib.rs
index b51221c1..3bfb3084 100644
--- a/bindings/node/native/src/lib.rs
+++ b/bindings/node/native/src/lib.rs
@@ -1,6 +1,7 @@
 #![warn(clippy::all)]
 
 extern crate neon;
+extern crate rayon;
 extern crate tokenizers as tk;
 
 mod container;
diff --git a/bindings/node/native/src/models.rs b/bindings/node/native/src/models.rs
index 30e0bfde..d03bab26 100644
--- a/bindings/node/native/src/models.rs
+++ b/bindings/node/native/src/models.rs
@@ -1,7 +1,6 @@
 extern crate tokenizers as tk;
 
 use crate::container::Container;
-use crate::encoding::JsEncoding;
 use crate::tasks::models::{BPEFromFilesTask, WordPieceFromFilesTask};
 use neon::prelude::*;
 use std::path::Path;
@@ -52,83 +51,6 @@ declare_types! {
                 Err(e) => cx.throw_error(format!("{}", e))
             }
         }
-
-        method encode(mut cx) {
-            /// encode(sequence: (String | [String, [number, number]])[], typeId?: number = 0):
-            ///   Encoding
-            let sequence = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
-            let type_id = cx.argument_opt(1)
-                .map_or(Some(0), |arg| arg.downcast::<JsNumber>()
-                    .ok()
-                    .map(|h| h.value() as u32)
-                ).unwrap();
-
-            enum Mode {
-                NoOffsets,
-                Offsets,
-            };
-            let mode = sequence.iter().next().map(|item| {
-                if item.downcast::<JsString>().is_ok() {
-                    Ok(Mode::NoOffsets)
-                } else if item.downcast::<JsArray>().is_ok() {
-                    Ok(Mode::Offsets)
-                } else {
-                    Err("Input must be (String | [String, [number, number]])[]")
-                }
-            })
-            .unwrap()
-            .map_err(|e| cx.throw_error::<_, ()>(e.to_string()).unwrap_err())?;
-
-            let mut total_len = 0;
-            let sequence = sequence.iter().map(|item| match mode {
-                Mode::NoOffsets => {
-                    let s = item.downcast::<JsString>().or_throw(&mut cx)?.value();
-                    let len = s.chars().count();
-                    total_len += len;
-                    Ok((s, (total_len - len, total_len)))
-                },
-                Mode::Offsets => {
-                    let tuple = item.downcast::<JsArray>().or_throw(&mut cx)?;
-                    let s = tuple.get(&mut cx, 0)?
-                        .downcast::<JsString>()
-                        .or_throw(&mut cx)?
-                        .value();
-                    let offsets = tuple.get(&mut cx, 1)?
-                        .downcast::<JsArray>()
-                        .or_throw(&mut cx)?;
-                    let (start, end) = (
-                        offsets.get(&mut cx, 0)?
-                            .downcast::<JsNumber>()
-                            .or_throw(&mut cx)?
-                            .value() as usize,
-                        offsets.get(&mut cx, 1)?
-                            .downcast::<JsNumber>()
-                            .or_throw(&mut cx)?
-                            .value() as usize,
-                    );
-                    Ok((s, (start, end)))
-                }
-            }).collect::<Result<Vec<_>, _>>()?;
-
-            let encoding = {
-                let this = cx.this();
-                let guard = cx.lock();
-                let res = this.borrow(&guard)
-                    .model
-                    .execute(|model| model.unwrap()
-                        .tokenize(sequence)
-                        .map(|tokens| tk::tokenizer::Encoding::from_tokens(tokens, type_id))
-                    );
-                res.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?
-            };
-
-            let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
-
-            let guard = cx.lock();
-            js_encoding.borrow_mut(&guard).encoding.to_owned(Box::new(encoding));
-
-            Ok(js_encoding.upcast())
-        }
     }
 }
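
Note: `JsModel.encode` is removed here; the same pre-tokenized input handling reappears below as the asynchronous `encodeTokenized` method on the Tokenizer. For reference, a minimal TypeScript sketch of the offset bookkeeping applied in `NoOffsets` mode (the `withOffsets` helper is hypothetical and only mirrors the `total_len` accumulation in the Rust code above):

// Hypothetical helper mirroring the NoOffsets branch: each plain string gets
// cumulative character offsets, like `total_len` in the Rust code above.
function withOffsets(words: string[]): [string, [number, number]][] {
  let totalLen = 0;
  return words.map((word): [string, [number, number]] => {
    const len = [...word].length; // code points, like Rust's chars().count()
    totalLen += len;
    return [word, [totalLen - len, totalLen]];
  });
}

// withOffsets(["Hello", "world"]) -> [["Hello", [0, 5]], ["world", [5, 10]]]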
diff --git a/bindings/node/native/src/tasks/tokenizer.rs b/bindings/node/native/src/tasks/tokenizer.rs
index 9b41e407..ad07a2a3 100644
--- a/bindings/node/native/src/tasks/tokenizer.rs
+++ b/bindings/node/native/src/tasks/tokenizer.rs
@@ -2,6 +2,7 @@ extern crate tokenizers as tk;
 
 use crate::encoding::*;
 use neon::prelude::*;
+use rayon::prelude::*;
 use tk::tokenizer::{EncodeInput, Encoding, Tokenizer};
 
 pub struct WorkingTokenizer {
@@ -100,6 +101,97 @@ impl Task for EncodeTask {
     }
 }
 
+pub enum EncodeTokenizedTask {
+    Single(WorkingTokenizer, Option<Vec<(String, (usize, usize))>>, u32),
+    Batch(
+        WorkingTokenizer,
+        Option<Vec<Vec<(String, (usize, usize))>>>,
+        u32,
+    ),
+}
+
+pub enum EncodeTokenizedOutput {
+    Single(Encoding),
+    Batch(Vec<Encoding>),
+}
+
+impl Task for EncodeTokenizedTask {
+    type Output = EncodeTokenizedOutput;
+    type Error = String;
+    type JsEvent = JsValue;
+
+    fn perform(&self) -> Result<Self::Output, Self::Error> {
+        match self {
+            EncodeTokenizedTask::Single(worker, input, type_id) => {
+                let input = unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
+                let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+
+                tokenizer
+                    .get_model()
+                    .tokenize(input.unwrap())
+                    .map_err(|e| format!("{}", e))
+                    .map(|tokens| {
+                        EncodeTokenizedOutput::Single(Encoding::from_tokens(tokens, *type_id))
+                    })
+            }
+            EncodeTokenizedTask::Batch(worker, input, type_id) => {
+                let input: Option<Vec<Vec<(String, (usize, usize))>>> =
+                    unsafe { std::ptr::replace(input as *const _ as *mut _, None) };
+                let tokenizer: &Tokenizer = unsafe { &*worker.ptr };
+
+                input
+                    .unwrap()
+                    .into_par_iter()
+                    .map(|input| {
+                        tokenizer
+                            .get_model()
+                            .tokenize(input)
+                            .map_err(|e| format!("{}", e))
+                            .map(|tokens| Encoding::from_tokens(tokens, *type_id))
+                    })
+                    .collect::<Result<Vec<_>, _>>()
+                    .map(EncodeTokenizedOutput::Batch)
+            }
+        }
+    }
+
+    fn complete(
+        self,
+        mut cx: TaskContext,
+        result: Result<Self::Output, Self::Error>,
+    ) -> JsResult<Self::JsEvent> {
+        match result.map_err(|e| cx.throw_error::<_, ()>(e).unwrap_err())? {
+            EncodeTokenizedOutput::Single(encoding) => {
+                let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
+                // Set the actual encoding
+                let guard = cx.lock();
+                js_encoding
+                    .borrow_mut(&guard)
+                    .encoding
+                    .to_owned(Box::new(encoding));
+
+                Ok(js_encoding.upcast())
+            }
+            EncodeTokenizedOutput::Batch(encodings) => {
+                let result = JsArray::new(&mut cx, encodings.len() as u32);
+                for (i, encoding) in encodings.into_iter().enumerate() {
+                    let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
+
+                    // Set the actual encoding
+                    let guard = cx.lock();
+                    js_encoding
+                        .borrow_mut(&guard)
+                        .encoding
+                        .to_owned(Box::new(encoding));
+
+                    result.set(&mut cx, i as u32, js_encoding)?;
+                }
+                Ok(result.upcast())
+            }
+        }
+    }
+}
+
 pub enum DecodeTask {
     Single(WorkingTokenizer, Vec<u32>, bool),
     Batch(WorkingTokenizer, Vec<Vec<u32>>, bool),
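
`perform` takes ownership of the input via `std::ptr::replace` (each task is scheduled exactly once), and the `Batch` variant tokenizes its sequences in parallel through rayon's `into_par_iter`; collecting into a `Result` means the first sequence that fails to tokenize fails the whole batch. Once wired up in `tokenizer.rs` below, a JS caller observes this as a single error-first callback. A hedged TypeScript sketch (the `tokenizer` handle is an assumed instance of the native Tokenizer class):

// Hedged sketch: the Batch task tokenizes sequences in parallel on the Rust
// side, but JS still receives one error-first callback with all encodings.
declare const tokenizer: any; // assumed handle to the native Tokenizer

tokenizer.encodeTokenizedBatch(
  [
    [["Hello", [0, 5]], ["world", [6, 11]]],
    [["Good", [0, 4]], ["night", [5, 10]]],
  ],
  0, // typeId
  (err: Error | null, encodings: any[]) => {
    if (err) throw err; // one failing sequence fails the entire batch
    console.log(encodings.length); // 2
  },
);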
diff --git a/bindings/node/native/src/tokenizer.rs b/bindings/node/native/src/tokenizer.rs
index addb8356..7cd199a9 100644
--- a/bindings/node/native/src/tokenizer.rs
+++ b/bindings/node/native/src/tokenizer.rs
@@ -6,7 +6,7 @@ use crate::models::JsModel;
 use crate::normalizers::JsNormalizer;
 use crate::pre_tokenizers::JsPreTokenizer;
 use crate::processors::JsPostProcessor;
-use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer};
+use crate::tasks::tokenizer::{DecodeTask, EncodeTask, EncodeTokenizedTask, WorkingTokenizer};
 use crate::trainers::JsTrainer;
 use neon::prelude::*;
 
@@ -237,6 +237,170 @@ declare_types! {
             Ok(cx.undefined().upcast())
         }
 
+        method encodeTokenized(mut cx) {
+            /// encodeTokenized(
+            ///   sequence: (String | [String, [number, number]])[],
+            ///   typeId?: number = 0,
+            ///   callback: (err, Encoding)
+            /// )
+
+            let sequence = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
+            let type_id = cx.argument_opt(1)
+                .map_or(Some(0), |arg| arg.downcast::<JsNumber>()
+                    .ok()
+                    .map(|h| h.value() as u32)
+                ).unwrap();
+
+            enum Mode {
+                NoOffsets,
+                Offsets,
+            };
+            let mode = sequence.iter().next().map(|item| {
+                if item.downcast::<JsString>().is_ok() {
+                    Ok(Mode::NoOffsets)
+                } else if item.downcast::<JsArray>().is_ok() {
+                    Ok(Mode::Offsets)
+                } else {
+                    Err("Input must be (String | [String, [number, number]])[]")
+                }
+            })
+            .unwrap()
+            .map_err(|e| cx.throw_error::<_, ()>(e.to_string()).unwrap_err())?;
+
+            let mut total_len = 0;
+            let sequence = sequence.iter().map(|item| match mode {
+                Mode::NoOffsets => {
+                    let s = item.downcast::<JsString>().or_throw(&mut cx)?.value();
+                    let len = s.chars().count();
+                    total_len += len;
+                    Ok((s, (total_len - len, total_len)))
+                },
+                Mode::Offsets => {
+                    let tuple = item.downcast::<JsArray>().or_throw(&mut cx)?;
+                    let s = tuple.get(&mut cx, 0)?
+                        .downcast::<JsString>()
+                        .or_throw(&mut cx)?
+                        .value();
+                    let offsets = tuple.get(&mut cx, 1)?
+                        .downcast::<JsArray>()
+                        .or_throw(&mut cx)?;
+                    let (start, end) = (
+                        offsets.get(&mut cx, 0)?
+                            .downcast::<JsNumber>()
+                            .or_throw(&mut cx)?
+                            .value() as usize,
+                        offsets.get(&mut cx, 1)?
+                            .downcast::<JsNumber>()
+                            .or_throw(&mut cx)?
+                            .value() as usize,
+                    );
+                    Ok((s, (start, end)))
+                }
+            }).collect::<Result<Vec<_>, _>>()?;
+            let callback = cx.argument::<JsFunction>(2)?;
+
+            let worker = {
+                let this = cx.this();
+                let guard = cx.lock();
+                let worker = this.borrow(&guard).prepare_for_task();
+                worker
+            };
+
+            let task = EncodeTokenizedTask::Single(worker, Some(sequence), type_id);
+            task.schedule(callback);
+            Ok(cx.undefined().upcast())
+        }
+
+        method encodeTokenizedBatch(mut cx) {
+            /// encodeTokenizedBatch(
+            ///   sequences: (String | [String, [number, number]])[][],
+            ///   typeId?: number = 0,
+            ///   callback: (err, Encoding[])
+            /// )
+
+            let sequences = cx.argument::<JsArray>(0)?.to_vec(&mut cx)?;
+            let type_id = cx.argument_opt(1)
+                .map_or(Some(0), |arg| arg.downcast::<JsNumber>()
+                    .ok()
+                    .map(|h| h.value() as u32)
+                ).unwrap();
+
+            enum Mode {
+                NoOffsets,
+                Offsets,
+            };
+            let mode = sequences.iter().next().map(|sequence| {
+                if let Ok(sequence) = sequence.downcast::<JsArray>().or_throw(&mut cx) {
+                    sequence.to_vec(&mut cx).ok().map(|s| s.iter().next().map(|item| {
+                        if item.downcast::<JsString>().is_ok() {
+                            Some(Mode::NoOffsets)
+                        } else if item.downcast::<JsArray>().is_ok() {
+                            Some(Mode::Offsets)
+                        } else {
+                            None
+                        }
+                    }).flatten()).flatten()
+                } else {
+                    None
+                }
+            })
+            .flatten()
+            .ok_or_else(||
+                cx.throw_error::<_, ()>(
+                    "Input must be (String | [String, [number, number]])[]"
+                ).unwrap_err()
+            )?;
+
+            let sequences = sequences.into_iter().map(|sequence| {
+                let mut total_len = 0;
+                sequence.downcast::<JsArray>().or_throw(&mut cx)?
+                    .to_vec(&mut cx)?
+                    .into_iter()
+                    .map(|item| match mode {
+                        Mode::NoOffsets => {
+                            let s = item.downcast::<JsString>().or_throw(&mut cx)?.value();
+                            let len = s.chars().count();
+                            total_len += len;
+                            Ok((s, (total_len - len, total_len)))
+                        },
+                        Mode::Offsets => {
+                            let tuple = item.downcast::<JsArray>().or_throw(&mut cx)?;
+                            let s = tuple.get(&mut cx, 0)?
+                                .downcast::<JsString>()
+                                .or_throw(&mut cx)?
+                                .value();
+                            let offsets = tuple.get(&mut cx, 1)?
+                                .downcast::<JsArray>()
+                                .or_throw(&mut cx)?;
+                            let (start, end) = (
+                                offsets.get(&mut cx, 0)?
+                                    .downcast::<JsNumber>()
+                                    .or_throw(&mut cx)?
+                                    .value() as usize,
+                                offsets.get(&mut cx, 1)?
+                                    .downcast::<JsNumber>()
+                                    .or_throw(&mut cx)?
+                                    .value() as usize,
+                            );
+                            Ok((s, (start, end)))
+                        }
+                    }).collect::<Result<Vec<_>, _>>()
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+            let callback = cx.argument::<JsFunction>(2)?;
+
+            let worker = {
+                let this = cx.this();
+                let guard = cx.lock();
+                let worker = this.borrow(&guard).prepare_for_task();
+                worker
+            };
+
+            let task = EncodeTokenizedTask::Batch(worker, Some(sequences), type_id);
+            task.schedule(callback);
+            Ok(cx.undefined().upcast())
+        }
+
         method decode(mut cx) {
             // decode(ids: number[], skipSpecialTokens: bool, callback)
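
For illustration, a hedged TypeScript sketch of how the two new methods might be driven from JavaScript; the `tokenizer` handle, the `promisify` wrapper, and the sample offsets are assumptions for this sketch, not part of the patch. Note that `typeId` must be passed explicitly when a callback follows, since the callback is read from argument position 2:

import { promisify } from "util";

declare const tokenizer: any; // assumed handle to the native Tokenizer class

// NoOffsets mode: plain strings; the binding derives cumulative character
// offsets exactly as the `total_len` accumulation above does.
tokenizer.encodeTokenized(["Hello", "world"], 0, (err: Error | null, encoding: any) => {
  if (err) throw err;
  // `encoding` is a single Encoding built from the given words
});

// Offsets mode: explicit [string, [start, end]] pairs, e.g. produced by a
// custom pre-tokenization step that tracked positions in the original text.
const encodeTokenized = promisify(tokenizer.encodeTokenized.bind(tokenizer));
async function example(): Promise<void> {
  const encoding = await encodeTokenized([["Hello", [0, 5]], ["world", [6, 11]]], 0);
}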