Adding a test to check that JsPreTokenizer serialization matches PreTokenizer serialization.
committed by Anthony MOI
parent 52e527dc9f
commit b8d6c1dece
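The change in brief: the node binding keeps pre-tokenizers behind a wrapper enum of its own, and that wrapper must serialize to exactly the same JSON as the native PreTokenizerWrapper, so that tokenizer files written from JavaScript remain loadable from Rust and vice versa. Below is a minimal self-contained sketch of the pattern under test, with stand-in types in place of the real tokenizers ones (only serde and serde_json are assumed):

use serde::ser::SerializeStruct;
use serde::{Serialize, Serializer};

// Stand-in for the native wrapper, internally tagged on "type" the way the
// tokenizers serialization format tags its components.
#[derive(Serialize)]
#[serde(tag = "type")]
enum NativeWrapper {
    Whitespace,
    Punctuation,
}

// Stand-in for the JS-side wrapper: Wrapped must defer to the inner value,
// while Sequence must hand-roll the same {"type":"Sequence",...} shape the
// native Sequence variant produces.
enum JsWrapper {
    Sequence(Vec<NativeWrapper>),
    Wrapped(NativeWrapper),
}

impl Serialize for JsWrapper {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match self {
            JsWrapper::Sequence(seq) => {
                let mut ser = serializer.serialize_struct("Sequence", 2)?;
                ser.serialize_field("type", "Sequence")?;
                ser.serialize_field("pretokenizers", seq)?;
                ser.end()
            }
            JsWrapper::Wrapped(inner) => inner.serialize(serializer),
        }
    }
}

fn main() {
    // Both representations must produce identical JSON, e.g. {"type":"Whitespace"}.
    let js = JsWrapper::Wrapped(NativeWrapper::Whitespace);
    let rs = NativeWrapper::Whitespace;
    assert_eq!(
        serde_json::to_string(&js).unwrap(),
        serde_json::to_string(&rs).unwrap()
    );

    // {"type":"Sequence","pretokenizers":[{"type":"Whitespace"},{"type":"Punctuation"}]}
    let js = JsWrapper::Sequence(vec![NativeWrapper::Whitespace, NativeWrapper::Punctuation]);
    println!("{}", serde_json::to_string(&js).unwrap());
}

The diff below implements exactly this shape for JsPreTokenizerWrapper and adds tests asserting the equality.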
@@ -12,7 +12,7 @@ import { EncodeOptions, InputSequence, Tokenizer } from "./tokenizer";
 
 const MOCKS_DIR = __dirname + "/__mocks__";
 
-describe("Encoding", () => {
+describe("Can modify pretokenizers on the fly", () => {
   let encoding: RawEncoding;
   let encode: (
     sequence: InputSequence,
@@ -30,28 +30,22 @@ describe("Encoding", () => {
     );
 
     tokenizer = new Tokenizer(model);
-    tokenizer.setPreTokenizer(whitespacePreTokenizer());
     encode = promisify(tokenizer.encode.bind(tokenizer));
   });
 
-  it("Encodes correctly", async () => {
-    encoding = await encode("my name is john", null);
-    expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
-    encoding = await encode("my name is john", null);
-    expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
-  });
-
   it("Can change pre tokenizer", async () => {
+    const input = "my name is john.!?";
     tokenizer.setPreTokenizer(sequencePreTokenizer([whitespacePreTokenizer()]));
 
-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6]);
 
+    // Change pre tokenizer
     tokenizer.setPreTokenizer(
      sequencePreTokenizer([whitespacePreTokenizer(), punctuationPreTokenizer()])
     );
 
-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6, 6, 6]);
   });
 });
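(Reading the assertions: with only the whitespace pre-tokenizer, the trailing ".!?" of the input surfaces as one token, the final 6; once the punctuation pre-tokenizer is added to the sequence, ".", "!" and "?" are split apart and three tokens with that id appear instead.)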
bindings/node/native/Cargo.lock (generated, 1 line added)

@@ -448,6 +448,7 @@ dependencies = [
  "neon-runtime",
  "neon-serde",
  "serde",
+ "serde_json",
  "tokenizers",
 ]
 
@@ -19,3 +19,4 @@ neon-runtime = "0.3"
 neon-serde = "0.3"
 serde = { version = "1.0", features = [ "rc", "derive" ] }
 tokenizers = { path = "../../../tokenizers" }
+serde_json = "1.0"
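serde_json is pulled in for the new serialization tests below; judging by its neighbors here (neon-serde, serde, tokenizers), it is added as a regular dependency rather than a dev-dependency.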
@@ -8,7 +8,7 @@ use std::sync::Arc;
 use tk::normalizers::NormalizerWrapper;
 use tk::NormalizedString;
 
-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub enum JsNormalizerWrapper {
     Sequence(Vec<Arc<NormalizerWrapper>>),
     Wrapped(Arc<NormalizerWrapper>),
@@ -213,3 +213,47 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Strip", prefix), strip)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::normalizers::unicode::{NFC, NFKC};
+    use tk::normalizers::utils::Sequence;
+    use tk::normalizers::NormalizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsNormalizerWrapper = NFKC.into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+
+        let rs_wrapped = NormalizerWrapper::NFKC(NFKC);
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        // let js_norm: Normalizer = serde_json::from_str(&rs_ser).unwrap();
+        // match js_norm.normalizer.unwrap() {
+        //     JsNormalizerWrapper::Wrapped(nfc) => match nfc.as_ref() {
+        //         NormalizerWrapper::NFKC(_) => {}
+        //         _ => panic!("Expected NFKC"),
+        //     },
+        //     _ => panic!("Expected wrapped, not sequence."),
+        // }
+
+        let js_seq: JsNormalizerWrapper = Sequence::new(vec![NFC.into(), NFKC.into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped =
+            NormalizerWrapper::Sequence(Sequence::new(vec![NFC.into(), NFKC.into()]).into());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = Normalizer {
+            normalizer: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![NFC.into(), NFKC.into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
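The last two assertions rely on serde's #[serde(flatten)]: a flattened Option field contributes the inner value's fields directly to the surrounding JSON object when it is Some, so wrapping js_seq in a Normalizer adds no extra nesting. A standalone illustration of that behavior (hypothetical Inner/Outer types; only serde and serde_json assumed):

use serde::Serialize;

#[derive(Serialize)]
struct Inner {
    kind: &'static str,
}

#[derive(Serialize)]
struct Outer {
    #[serde(flatten)]
    inner: Option<Inner>,
}

fn main() {
    // The flattened Option adds no wrapper object: Outer serializes to the
    // same JSON as its Inner, which is why assert_eq!(js_wrapper_ser, js_ser)
    // above can hold.
    let outer = Outer {
        inner: Some(Inner { kind: "NFKC" }),
    };
    assert_eq!(serde_json::to_string(&outer).unwrap(), r#"{"kind":"NFKC"}"#);
}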
@@ -4,22 +4,63 @@ use crate::extraction::*;
 use neon::prelude::*;
 use std::sync::Arc;
 
+use serde::{ser::SerializeStruct, Serialize, Serializer};
 use tk::pre_tokenizers::PreTokenizerWrapper;
 use tk::PreTokenizedString;
 
+#[derive(Clone, Debug, Deserialize)]
+pub enum JsPreTokenizerWrapper {
+    Sequence(Vec<Arc<PreTokenizerWrapper>>),
+    Wrapped(Arc<PreTokenizerWrapper>),
+}
+
+impl Serialize for JsPreTokenizerWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                let mut ser = serializer.serialize_struct("Sequence", 2)?;
+                ser.serialize_field("type", "Sequence")?;
+                ser.serialize_field("pretokenizers", seq)?;
+                ser.end()
+            }
+            JsPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
+        }
+    }
+}
+
+impl<I> From<I> for JsPreTokenizerWrapper
+where
+    I: Into<PreTokenizerWrapper>,
+{
+    fn from(norm: I) -> Self {
+        JsPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    }
+}
+
 /// PreTokenizers
 #[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct PreTokenizer {
     #[serde(flatten)]
-    pub pretok: Option<Arc<PreTokenizerWrapper>>,
+    pub pretok: Option<JsPreTokenizerWrapper>,
 }
 
 impl tk::PreTokenizer for PreTokenizer {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> tk::Result<()> {
-        self.pretok
-            .as_ref()
-            .ok_or("Uninitialized PreTokenizer")?
-            .pre_tokenize(pretokenized)
+        match self.pretok.as_ref().ok_or("Uninitialized PreTokenizer")? {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                for pretokenizer in seq {
+                    pretokenizer.pre_tokenize(pretokenized)?;
+                }
+            }
+            JsPreTokenizerWrapper::Wrapped(pretokenizer) => {
+                pretokenizer.pre_tokenize(pretokenized)?
+            }
+        };
+
+        Ok(())
     }
 }
 
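Two things are worth noting in this rewrite. First, the wrapper holds Arc<PreTokenizerWrapper> elements instead of building a native Sequence, which lets the sequence() binding further down splice pre-tokenizers together by cloning cheap Arc handles. Second, the Sequence variant therefore no longer serializes for free: the manual Serialize impl reproduces the {"type": "Sequence", "pretokenizers": [...]} shape by hand, and the test module added at the end of this file pins that shape against the native PreTokenizerWrapper output.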
@@ -41,7 +82,7 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(byte_level.into()));
+    pretok.borrow_mut(&guard).pretok = Some(byte_level.into());
     Ok(pretok)
 }
 
@@ -59,9 +100,8 @@ fn byte_level_alphabet(mut cx: FunctionContext) -> JsResult<JsValue> {
 fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::Whitespace::default().into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::whitespace::Whitespace::default().into());
     Ok(pretok)
 }
 
@@ -69,9 +109,7 @@ fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::WhitespaceSplit.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::whitespace::WhitespaceSplit.into());
     Ok(pretok)
 }
 
@@ -79,8 +117,7 @@ fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn bert_pre_tokenizer(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok =
-        Some(Arc::new(tk::pre_tokenizers::bert::BertPreTokenizer.into()));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::bert::BertPreTokenizer.into());
     Ok(pretok)
 }
 
@@ -91,9 +128,8 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into());
     Ok(pretok)
 }
 
@@ -101,9 +137,7 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn punctuation(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::punctuation::Punctuation.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::punctuation::Punctuation.into());
     Ok(pretok)
 }
 
@@ -118,13 +152,15 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         |pretokenizer| match pretokenizer.downcast::<JsPreTokenizer>().or_throw(&mut cx) {
             Ok(pretokenizer) => {
                 let guard = cx.lock();
-                let pretokenizer = (*pretokenizer.borrow(&guard)).pretok.clone();
-                if let Some(pretokenizer) = pretokenizer {
-                    let pretok = (*pretokenizer).clone();
-                    sequence.push(pretok);
+                let pretok = (*pretokenizer.borrow(&guard)).pretok.clone();
+                if let Some(pretokenizer) = pretok {
+                    match pretokenizer {
+                        JsPreTokenizerWrapper::Sequence(seq) => sequence.extend(seq),
+                        JsPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner),
+                    }
                     Ok(())
                 } else {
-                    cx.throw_error("Uninitialized Normalizer")
+                    cx.throw_error("Uninitialized PreTokenizer")
                 }
             }
             Err(e) => Err(e),
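A side effect worth calling out: pushing a JS-side Sequence into another sequence now splices its elements into the parent (sequence.extend(seq)) rather than nesting a sequence inside a sequence, and the shared Arc handles make that cheap. The error message also loses its copy-paste leftover, "Uninitialized Normalizer" becoming "Uninitialized PreTokenizer".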
@@ -133,9 +169,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         .collect::<NeonResult<_>>()?;
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::sequence::Sequence::new(sequence).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(JsPreTokenizerWrapper::Sequence(sequence).into());
     Ok(pretok)
 }
 
@@ -145,9 +179,8 @@ fn char_delimiter_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into());
 
     Ok(pretok)
 }
@@ -171,3 +204,40 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Sequence", prefix), sequence)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
+    use tk::pre_tokenizers::PreTokenizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsPreTokenizerWrapper = Whitespace::default().into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+
+        let rs_wrapped = PreTokenizerWrapper::Whitespace(Whitespace::default());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        let js_seq: JsPreTokenizerWrapper =
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into(),
+        );
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = PreTokenizer {
+            pretok: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
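Like its normalizer counterpart above, this test exercises only the Rust side of the binding, so it should run with a plain cargo test from bindings/node/native, no node runtime involved.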