From b8d6c1deced1c472ab2f01167abd8c3428c46189 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 2 Sep 2020 16:56:18 +0200
Subject: [PATCH] Adding a test to check that serialization of JsPreTokenizer
 fits PreTokenizer serialization.

---
 .../node/lib/bindings/raw-encoding.test.ts |  18 +--
 bindings/node/native/Cargo.lock            |   1 +
 bindings/node/native/Cargo.toml            |   1 +
 bindings/node/native/src/normalizers.rs    |  46 +++++-
 bindings/node/native/src/pre_tokenizers.rs | 132 ++++++++++++++----
 5 files changed, 154 insertions(+), 44 deletions(-)

diff --git a/bindings/node/lib/bindings/raw-encoding.test.ts b/bindings/node/lib/bindings/raw-encoding.test.ts
index f679a6db..aaa47b71 100644
--- a/bindings/node/lib/bindings/raw-encoding.test.ts
+++ b/bindings/node/lib/bindings/raw-encoding.test.ts
@@ -12,7 +12,7 @@ import { EncodeOptions, InputSequence, Tokenizer } from "./tokenizer";
 
 const MOCKS_DIR = __dirname + "/__mocks__";
 
-describe("Encoding", () => {
+describe("Can modify pretokenizers on the fly", () => {
   let encoding: RawEncoding;
   let encode: (
     sequence: InputSequence,
@@ -30,28 +30,22 @@
     );
 
     tokenizer = new Tokenizer(model);
-    tokenizer.setPreTokenizer(whitespacePreTokenizer());
     encode = promisify(tokenizer.encode.bind(tokenizer));
   });
 
-  it("Encodes correctly", async () => {
-    encoding = await encode("my name is john", null);
-    expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
-
-    encoding = await encode("my name is john", null);
-    expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
-  });
-
   it("Can change pre tokenizer", async () => {
+    const input = "my name is john.!?";
     tokenizer.setPreTokenizer(sequencePreTokenizer([whitespacePreTokenizer()]));
 
-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6]);
+
+    // Change pre tokenizer
     tokenizer.setPreTokenizer(
       sequencePreTokenizer([whitespacePreTokenizer(), punctuationPreTokenizer()])
     );
-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6, 6, 6]);
   });
 });
diff --git a/bindings/node/native/Cargo.lock b/bindings/node/native/Cargo.lock
index 2036eda9..2c8dc80d 100644
--- a/bindings/node/native/Cargo.lock
+++ b/bindings/node/native/Cargo.lock
@@ -448,6 +448,7 @@ dependencies = [
  "neon-runtime",
  "neon-serde",
  "serde",
+ "serde_json",
  "tokenizers",
 ]
diff --git a/bindings/node/native/Cargo.toml b/bindings/node/native/Cargo.toml
index 1f488081..1ea547ad 100644
--- a/bindings/node/native/Cargo.toml
+++ b/bindings/node/native/Cargo.toml
@@ -19,3 +19,4 @@ neon-runtime = "0.3"
 neon-serde = "0.3"
 serde = { version = "1.0", features = [ "rc", "derive" ] }
 tokenizers = { path = "../../../tokenizers" }
+serde_json = "1.0"
diff --git a/bindings/node/native/src/normalizers.rs b/bindings/node/native/src/normalizers.rs
index 48e8c1a5..4a540497 100644
--- a/bindings/node/native/src/normalizers.rs
+++ b/bindings/node/native/src/normalizers.rs
@@ -8,7 +8,7 @@ use std::sync::Arc;
 use tk::normalizers::NormalizerWrapper;
 use tk::NormalizedString;
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub enum JsNormalizerWrapper {
     Sequence(Vec<Arc<NormalizerWrapper>>),
     Wrapped(Arc<NormalizerWrapper>),
 }
@@ -213,3 +213,47 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Strip", prefix), strip)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::normalizers::unicode::{NFC, NFKC};
+    use tk::normalizers::utils::Sequence;
+    use tk::normalizers::NormalizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsNormalizerWrapper = NFKC.into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+
+        let rs_wrapped = NormalizerWrapper::NFKC(NFKC);
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        // let js_norm: Normalizer = serde_json::from_str(&rs_ser).unwrap();
+        // match js_norm.normalizer.unwrap() {
+        //     JsNormalizerWrapper::Wrapped(nfc) => match nfc.as_ref() {
+        //         NormalizerWrapper::NFKC(_) => {}
+        //         _ => panic!("Expected NFKC"),
+        //     },
+        //     _ => panic!("Expected wrapped, not sequence."),
+        // }
+
+        let js_seq: JsNormalizerWrapper = Sequence::new(vec![NFC.into(), NFKC.into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped =
+            NormalizerWrapper::Sequence(Sequence::new(vec![NFC.into(), NFKC.into()]).into());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = Normalizer {
+            normalizer: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![NFC.into(), NFKC.into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
diff --git a/bindings/node/native/src/pre_tokenizers.rs b/bindings/node/native/src/pre_tokenizers.rs
index 11da1964..ea270b55 100644
--- a/bindings/node/native/src/pre_tokenizers.rs
+++ b/bindings/node/native/src/pre_tokenizers.rs
@@ -4,22 +4,63 @@ use crate::extraction::*;
 use neon::prelude::*;
 use std::sync::Arc;
 
+use serde::{ser::SerializeStruct, Serialize, Serializer};
 use tk::pre_tokenizers::PreTokenizerWrapper;
 use tk::PreTokenizedString;
 
+#[derive(Clone, Debug, Deserialize)]
+pub enum JsPreTokenizerWrapper {
+    Sequence(Vec<Arc<PreTokenizerWrapper>>),
+    Wrapped(Arc<PreTokenizerWrapper>),
+}
+
+impl Serialize for JsPreTokenizerWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                let mut ser = serializer.serialize_struct("Sequence", 2)?;
+                ser.serialize_field("type", "Sequence")?;
+                ser.serialize_field("pretokenizers", seq)?;
+                ser.end()
+            }
+            JsPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
+        }
+    }
+}
+
+impl<I> From<I> for JsPreTokenizerWrapper
+where
+    I: Into<PreTokenizerWrapper>,
+{
+    fn from(norm: I) -> Self {
+        JsPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    }
+}
+
 /// PreTokenizers
 #[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct PreTokenizer {
     #[serde(flatten)]
-    pub pretok: Option<Arc<PreTokenizerWrapper>>,
+    pub pretok: Option<JsPreTokenizerWrapper>,
 }
 
 impl tk::PreTokenizer for PreTokenizer {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> tk::Result<()> {
-        self.pretok
-            .as_ref()
-            .ok_or("Uninitialized PreTokenizer")?
-            .pre_tokenize(pretokenized)
+        match self.pretok.as_ref().ok_or("Uninitialized PreTokenizer")? {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                for pretokenizer in seq {
+                    pretokenizer.pre_tokenize(pretokenized)?;
+                }
+            }
+            JsPreTokenizerWrapper::Wrapped(pretokenizer) => {
+                pretokenizer.pre_tokenize(pretokenized)?
+            }
+        };
+
+        Ok(())
     }
 }
 
@@ -41,7 +82,7 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(byte_level.into()));
+    pretok.borrow_mut(&guard).pretok = Some(byte_level.into());
     Ok(pretok)
 }
 
@@ -59,9 +100,8 @@ fn byte_level_alphabet(mut cx: FunctionContext) -> JsResult<JsValue> {
 fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::Whitespace::default().into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::whitespace::Whitespace::default().into());
     Ok(pretok)
 }
 
@@ -69,9 +109,7 @@ fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::WhitespaceSplit.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::whitespace::WhitespaceSplit.into());
     Ok(pretok)
 }
 
@@ -79,8 +117,7 @@ fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn bert_pre_tokenizer(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok =
-        Some(Arc::new(tk::pre_tokenizers::bert::BertPreTokenizer.into()));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::bert::BertPreTokenizer.into());
     Ok(pretok)
 }
 
@@ -91,9 +128,8 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into());
     Ok(pretok)
 }
 
@@ -101,9 +137,7 @@ fn punctuation(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::punctuation::Punctuation.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::punctuation::Punctuation.into());
     Ok(pretok)
 }
 
@@ -118,13 +152,15 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
             |pretokenizer| match pretokenizer.downcast::<JsPreTokenizer>().or_throw(&mut cx) {
                 Ok(pretokenizer) => {
                     let guard = cx.lock();
-                    let pretokenizer = (*pretokenizer.borrow(&guard)).pretok.clone();
-                    if let Some(pretokenizer) = pretokenizer {
-                        let pretok = (*pretokenizer).clone();
-                        sequence.push(pretok);
+                    let pretok = (*pretokenizer.borrow(&guard)).pretok.clone();
+                    if let Some(pretokenizer) = pretok {
+                        match pretokenizer {
+                            JsPreTokenizerWrapper::Sequence(seq) => sequence.extend(seq),
+                            JsPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner),
+                        }
                         Ok(())
                     } else {
-                        cx.throw_error("Uninitialized Normalizer")
+                        cx.throw_error("Uninitialized PreTokenizer")
                     }
                 }
                 Err(e) => Err(e),
@@ -133,9 +169,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         .collect::<NeonResult<Vec<_>>>()?;
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::sequence::Sequence::new(sequence).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(JsPreTokenizerWrapper::Sequence(sequence).into());
     Ok(pretok)
 }
 
@@ -145,9 +179,8 @@ fn char_delimiter_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into());
     Ok(pretok)
 }
 
@@ -171,3 +204,40 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Sequence", prefix), sequence)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
+    use tk::pre_tokenizers::PreTokenizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsPreTokenizerWrapper = Whitespace::default().into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+
+        let rs_wrapped = PreTokenizerWrapper::Whitespace(Whitespace::default());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        let js_seq: JsPreTokenizerWrapper =
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into(),
+        );
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = PreTokenizer {
+            pretok: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
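
Side note on the reverse direction, which the new tests do not exercise (the
normalizers test keeps its deserialization check commented out): the
pre-tokenizer analogue of that commented-out round-trip would read roughly as
below. This is only a sketch, assuming the JsPreTokenizerWrapper and
PreTokenizer types introduced in this patch, not part of the change itself:

    // Serialize a plain Rust-side PreTokenizerWrapper...
    let rs_ser =
        serde_json::to_string(&PreTokenizerWrapper::Whitespace(Whitespace::default())).unwrap();
    // ...and check that it deserializes back into the JS-side wrapper
    // as the Wrapped variant.
    let js_pretok: PreTokenizer = serde_json::from_str(&rs_ser).unwrap();
    match js_pretok.pretok.unwrap() {
        JsPreTokenizerWrapper::Wrapped(inner) => match inner.as_ref() {
            PreTokenizerWrapper::Whitespace(_) => {}
            _ => panic!("Expected Whitespace"),
        },
        _ => panic!("Expected wrapped, not sequence."),
    }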