Adding a test to check that serialization of JsPreTokenizer fits PreTokenizer serialization.
Nicolas Patry
2020-09-02 16:56:18 +02:00
committed by Anthony MOI
parent 52e527dc9f
commit b8d6c1dece
5 changed files with 154 additions and 44 deletions

@@ -12,7 +12,7 @@ import { EncodeOptions, InputSequence, Tokenizer } from "./tokenizer";
 const MOCKS_DIR = __dirname + "/__mocks__";
 
-describe("Encoding", () => {
+describe("Can modify pretokenizers on the fly", () => {
   let encoding: RawEncoding;
   let encode: (
     sequence: InputSequence,
@@ -30,28 +30,22 @@ describe("Encoding", () => {
     );
     tokenizer = new Tokenizer(model);
+    tokenizer.setPreTokenizer(whitespacePreTokenizer());
     encode = promisify(tokenizer.encode.bind(tokenizer));
   });
 
-  it("Encodes correctly", async () => {
-    encoding = await encode("my name is john", null);
-    expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
-    encoding = await encode("my name is john", null);
-    expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
-  });
 
   it("Can change pre tokenizer", async () => {
+    const input = "my name is john.!?";
     tokenizer.setPreTokenizer(sequencePreTokenizer([whitespacePreTokenizer()]));
-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6]);
 
-    // Change pre tokenizer
     tokenizer.setPreTokenizer(
       sequencePreTokenizer([whitespacePreTokenizer(), punctuationPreTokenizer()])
     );
-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6, 6, 6]);
   });
 });
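A note on the expected ids: with whitespacePreTokenizer() alone the trailing ".!?" stays a single piece (the single trailing id 6), and adding punctuationPreTokenizer() splits it into ".", "!" and "?", hence id 6 appearing three times in the last assertion. The same pipeline, sketched from the Rust side with the types this commit touches (illustration only, not part of the diff; it assumes the crate alias `tk` and its `From<&str>` impl for `PreTokenizedString`):

use tk::pre_tokenizers::punctuation::Punctuation;
use tk::pre_tokenizers::sequence::Sequence;
use tk::pre_tokenizers::whitespace::Whitespace;
use tk::{PreTokenizedString, PreTokenizer};

fn demo() -> tk::Result<()> {
    // Same as sequencePreTokenizer([whitespacePreTokenizer(), punctuationPreTokenizer()])
    let pretok = Sequence::new(vec![Whitespace::default().into(), Punctuation.into()]);
    let mut input: PreTokenizedString = "my name is john.!?".into();
    // Whitespace splits on word/punctuation boundaries; Punctuation then
    // breaks the ".!?" run into ".", "!", "?".
    pretok.pre_tokenize(&mut input)
}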

@@ -448,6 +448,7 @@ dependencies = [
  "neon-runtime",
  "neon-serde",
  "serde",
+ "serde_json",
  "tokenizers",
 ]

@@ -19,3 +19,4 @@ neon-runtime = "0.3"
 neon-serde = "0.3"
 serde = { version = "1.0", features = [ "rc", "derive" ] }
 tokenizers = { path = "../../../tokenizers" }
+serde_json = "1.0"

@@ -8,7 +8,7 @@ use std::sync::Arc;
 use tk::normalizers::NormalizerWrapper;
 use tk::NormalizedString;
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub enum JsNormalizerWrapper {
     Sequence(Vec<Arc<NormalizerWrapper>>),
     Wrapped(Arc<NormalizerWrapper>),
@@ -213,3 +213,47 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Strip", prefix), strip)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::normalizers::unicode::{NFC, NFKC};
+    use tk::normalizers::utils::Sequence;
+    use tk::normalizers::NormalizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsNormalizerWrapper = NFKC.into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+        let rs_wrapped = NormalizerWrapper::NFKC(NFKC);
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        // let js_norm: Normalizer = serde_json::from_str(&rs_ser).unwrap();
+        // match js_norm.normalizer.unwrap() {
+        //     JsNormalizerWrapper::Wrapped(nfc) => match nfc.as_ref() {
+        //         NormalizerWrapper::NFKC(_) => {}
+        //         _ => panic!("Expected NFKC"),
+        //     },
+        //     _ => panic!("Expected wrapped, not sequence."),
+        // }
+
+        let js_seq: JsNormalizerWrapper = Sequence::new(vec![NFC.into(), NFKC.into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped =
+            NormalizerWrapper::Sequence(Sequence::new(vec![NFC.into(), NFKC.into()]).into());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = Normalizer {
+            normalizer: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![NFC.into(), NFKC.into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
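For reference, a sketch of the JSON each assertion pair above must agree on (an illustration under the assumption that the upstream unit normalizers serialize as bare type-tagged objects):

// Both wrappers are expected to emit:
//   NFKC                -> {"type":"NFKC"}
//   Sequence(NFC, NFKC) -> {"type":"Sequence","normalizers":[{"type":"NFC"},{"type":"NFKC"}]}
let js_wrapped: JsNormalizerWrapper = NFKC.into();
assert_eq!(serde_json::to_string(&js_wrapped).unwrap(), r#"{"type":"NFKC"}"#);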

@@ -4,22 +4,63 @@ use crate::extraction::*;
 use neon::prelude::*;
 use std::sync::Arc;
 
+use serde::{ser::SerializeStruct, Serialize, Serializer};
 use tk::pre_tokenizers::PreTokenizerWrapper;
 use tk::PreTokenizedString;
 
+#[derive(Clone, Debug, Deserialize)]
+pub enum JsPreTokenizerWrapper {
+    Sequence(Vec<Arc<PreTokenizerWrapper>>),
+    Wrapped(Arc<PreTokenizerWrapper>),
+}
+
+impl Serialize for JsPreTokenizerWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                let mut ser = serializer.serialize_struct("Sequence", 2)?;
+                ser.serialize_field("type", "Sequence")?;
+                ser.serialize_field("pretokenizers", seq)?;
+                ser.end()
+            }
+            JsPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
+        }
+    }
+}
+
+impl<I> From<I> for JsPreTokenizerWrapper
+where
+    I: Into<PreTokenizerWrapper>,
+{
+    fn from(norm: I) -> Self {
+        JsPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    }
+}
+
 /// PreTokenizers
 #[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct PreTokenizer {
     #[serde(flatten)]
-    pub pretok: Option<Arc<PreTokenizerWrapper>>,
+    pub pretok: Option<JsPreTokenizerWrapper>,
 }
 
 impl tk::PreTokenizer for PreTokenizer {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> tk::Result<()> {
-        self.pretok
-            .as_ref()
-            .ok_or("Uninitialized PreTokenizer")?
-            .pre_tokenize(pretokenized)
+        match self.pretok.as_ref().ok_or("Uninitialized PreTokenizer")? {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                for pretokenizer in seq {
+                    pretokenizer.pre_tokenize(pretokenized)?;
+                }
+            }
+            JsPreTokenizerWrapper::Wrapped(pretokenizer) => {
+                pretokenizer.pre_tokenize(pretokenized)?
+            }
+        };
+        Ok(())
     }
 }
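The hand-written Serialize is the heart of the change: a sequence assembled from JS is stored as a plain Vec<Arc<PreTokenizerWrapper>>, yet must serialize exactly like tk's own Sequence struct. A sketch of the two output shapes, assuming the code above is in scope and that WhitespaceSplit serializes as a bare type tag:

// Wrapped(_) just delegates to the inner PreTokenizerWrapper:
let wrapped: JsPreTokenizerWrapper = tk::pre_tokenizers::whitespace::WhitespaceSplit.into();
// serde_json::to_string(&wrapped) -> {"type":"WhitespaceSplit"}

// Sequence(_) mimics the layout of tk::pre_tokenizers::sequence::Sequence:
let seq = JsPreTokenizerWrapper::Sequence(vec![Arc::new(
    tk::pre_tokenizers::whitespace::WhitespaceSplit.into(),
)]);
// serde_json::to_string(&seq) -> {"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"}]}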
@@ -41,7 +82,7 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(byte_level.into()));
+    pretok.borrow_mut(&guard).pretok = Some(byte_level.into());
     Ok(pretok)
 }
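The same one-line change repeats in every constructor below: the blanket From impl above now does the Arc wrapping, so Some(Arc::new(x.into())) shrinks to Some(x.into()). Spelled out once (standalone sketch using the types from this file):

let pretok: JsPreTokenizerWrapper = tk::pre_tokenizers::bert::BertPreTokenizer.into();
// ...which the From impl expands to:
let pretok = JsPreTokenizerWrapper::Wrapped(Arc::new(
    tk::pre_tokenizers::bert::BertPreTokenizer.into(),
));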
@@ -59,9 +100,8 @@ fn byte_level_alphabet(mut cx: FunctionContext) -> JsResult<JsValue> {
 fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::Whitespace::default().into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::whitespace::Whitespace::default().into());
     Ok(pretok)
 }
@@ -69,9 +109,7 @@ fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::WhitespaceSplit.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::whitespace::WhitespaceSplit.into());
     Ok(pretok)
 }
@@ -79,8 +117,7 @@ fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn bert_pre_tokenizer(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok =
-        Some(Arc::new(tk::pre_tokenizers::bert::BertPreTokenizer.into()));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::bert::BertPreTokenizer.into());
     Ok(pretok)
 }
@@ -91,9 +128,8 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into());
     Ok(pretok)
 }
@@ -101,9 +137,7 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn punctuation(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::punctuation::Punctuation.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::punctuation::Punctuation.into());
     Ok(pretok)
 }
@@ -118,13 +152,15 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         |pretokenizer| match pretokenizer.downcast::<JsPreTokenizer>().or_throw(&mut cx) {
             Ok(pretokenizer) => {
                 let guard = cx.lock();
-                let pretokenizer = (*pretokenizer.borrow(&guard)).pretok.clone();
-                if let Some(pretokenizer) = pretokenizer {
-                    let pretok = (*pretokenizer).clone();
-                    sequence.push(pretok);
+                let pretok = (*pretokenizer.borrow(&guard)).pretok.clone();
+                if let Some(pretokenizer) = pretok {
+                    match pretokenizer {
+                        JsPreTokenizerWrapper::Sequence(seq) => sequence.extend(seq),
+                        JsPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner),
+                    }
                     Ok(())
                 } else {
-                    cx.throw_error("Uninitialized Normalizer")
+                    cx.throw_error("Uninitialized PreTokenizer")
                 }
             }
             Err(e) => Err(e),
@@ -133,9 +169,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         .collect::<NeonResult<_>>()?;
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::sequence::Sequence::new(sequence).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(JsPreTokenizerWrapper::Sequence(sequence).into());
     Ok(pretok)
 }
@@ -145,9 +179,8 @@ fn char_delimiter_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into());
     Ok(pretok)
 }
@@ -171,3 +204,40 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Sequence", prefix), sequence)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
+    use tk::pre_tokenizers::PreTokenizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsPreTokenizerWrapper = Whitespace::default().into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Whitespace(Whitespace::default());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        let js_seq: JsPreTokenizerWrapper =
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into(),
+        );
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = PreTokenizer {
+            pretok: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
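The last three assertions in this test converge on one string. Under the same bare-type-tag assumption as in the normalizer test, that shared serialization would be (a hypothetical extension of the test above):

let expected =
    r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Whitespace"}]}"#;
assert_eq!(js_wrapper_ser, expected);

Matching tk's own format here is what keeps a tokenizer saved from the Node bindings loadable by the Rust library, which is the point of the commit.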