diff --git a/bindings/node/native/src/encoding.rs b/bindings/node/native/src/encoding.rs index cd519d68..0ebe37af 100644 --- a/bindings/node/native/src/encoding.rs +++ b/bindings/node/native/src/encoding.rs @@ -1,5 +1,7 @@ extern crate tokenizers as tk; +use tk::tokenizer::PaddingDirection; + use crate::utils::Container; use neon::prelude::*; @@ -122,5 +124,92 @@ declare_types! { Ok(js_offsets.upcast()) } + + method getOverflowing(mut cx) { + // getOverflowing(): Encoding | undefined; + + let this = cx.this(); + let guard = cx.lock(); + + let overflowing = this.borrow(&guard).encoding.execute(|encoding| { + encoding.unwrap().get_overflowing().cloned() + }); + + if let Some(overflowing) = overflowing { + let mut js_overflowing = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?; + + // Set the content + let guard = cx.lock(); + js_overflowing.borrow_mut(&guard).encoding.to_owned(Box::new(overflowing)); + + Ok(js_overflowing.upcast()) + } else { + Ok(cx.undefined().upcast()) + } + } + + method pad(mut cx) { + // pad(length: number, options?: { + // direction?: 'left' | 'right' = 'right', + // padId?: number = 0, + // padTypeId?: number = 0, + // padToken?: string = "[PAD]" + // } + + let length = cx.argument::(0)?.value() as usize; + let mut direction = PaddingDirection::Right; + let mut pad_id = 0; + let mut pad_type_id = 0; + let mut pad_token = String::from("[PAD]"); + + let options = cx.argument_opt(1); + if let Some(options) = options { + if let Ok(options) = options.downcast::() { + if let Ok(dir) = options.get(&mut cx, "direction") { + let dir = dir.downcast::().or_throw(&mut cx)?.value(); + match &dir[..] { + "right" => direction = PaddingDirection::Right, + "left" => direction = PaddingDirection::Left, + _ => return cx.throw_error("direction can be 'right' or 'left'"), + } + } + if let Ok(pid) = options.get(&mut cx, "padId") { + pad_id = pid.downcast::().or_throw(&mut cx)?.value() as u32; + } + if let Ok(pid) = options.get(&mut cx, "padTypeId") { + pad_type_id = pid.downcast::().or_throw(&mut cx)?.value() as u32; + } + if let Ok(token) = options.get(&mut cx, "padToken") { + pad_token = token.downcast::().or_throw(&mut cx)?.value(); + } + } + } + + let mut this = cx.this(); + let guard = cx.lock(); + this.borrow_mut(&guard).encoding.execute_mut(|encoding| { + encoding.unwrap().pad(length, pad_id, pad_type_id, &pad_token, &direction); + }); + + Ok(cx.undefined().upcast()) + } + + method truncate(mut cx) { + // truncate(length: number, stride: number = 0) + + let length = cx.argument::(0)?.value() as usize; + let mut stride = 0; + if let Some(args) = cx.argument_opt(1) { + stride = args.downcast::().or_throw(&mut cx)?.value() as usize; + } + + let mut this = cx.this(); + let guard = cx.lock(); + this.borrow_mut(&guard).encoding.execute_mut(|encoding| { + encoding.unwrap().truncate(length, stride); + }); + + Ok(cx.undefined().upcast()) + } } } diff --git a/bindings/node/native/src/utils.rs b/bindings/node/native/src/utils.rs index b31368d5..2a838c31 100644 --- a/bindings/node/native/src/utils.rs +++ b/bindings/node/native/src/utils.rs @@ -74,4 +74,21 @@ where Container::Empty => closure(None), } } + + pub fn execute_mut(&mut self, closure: F) -> U + where + F: FnOnce(Option<&mut Box>) -> U, + { + match self { + Container::Owned(val) => closure(Some(val)), + Container::Pointer(ptr) => unsafe { + let mut val = Box::from_raw(*ptr); + let res = closure(Some(&mut val)); + // We call this to make sure we don't drop the Box + Box::into_raw(val); + res + }, + Container::Empty => closure(None), + } + } }