Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 00:35:35 +00:00)
Adding a test to check that serialization of JsPreTokenizer fits PreTokenizer serialization.
Committed by: Anthony MOI
Parent: 52e527dc9f
Commit: b8d6c1dece
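For context: the Node bindings wrap the Rust library's PreTokenizerWrapper in a JsPreTokenizerWrapper enum, and the new tests pin down the invariant that both sides serialize to identical JSON, so tokenizer files written from JavaScript stay loadable from Rust and vice versa. A minimal sketch of that invariant (our own illustration, not code from this commit; same_json is a hypothetical helper):

    // Sketch of the invariant the new tests enforce: the JS-side wrapper and
    // the Rust-side wrapper it contains must serialize to the same JSON.
    // `same_json` is a hypothetical helper, not part of the commit.
    fn same_json<A: serde::Serialize, B: serde::Serialize>(a: &A, b: &B) -> bool {
        // Compare as Values so key order cannot cause false negatives.
        serde_json::to_value(a).unwrap() == serde_json::to_value(b).unwrap()
    }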
@@ -12,7 +12,7 @@ import { EncodeOptions, InputSequence, Tokenizer } from "./tokenizer";

 const MOCKS_DIR = __dirname + "/__mocks__";

 describe("Encoding", () => {
   describe("Can modify pretokenizers on the fly", () => {
     let encoding: RawEncoding;
     let encode: (
       sequence: InputSequence,

@@ -30,28 +30,22 @@ describe("Encoding", () => {
     );

     tokenizer = new Tokenizer(model);
     tokenizer.setPreTokenizer(whitespacePreTokenizer());
     encode = promisify(tokenizer.encode.bind(tokenizer));
   });

   it("Encodes correctly", async () => {
     encoding = await encode("my name is john", null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);

     encoding = await encode("my name is john", null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4]);
   });

   it("Can change pre tokenizer", async () => {
+    const input = "my name is john.!?";
     tokenizer.setPreTokenizer(sequencePreTokenizer([whitespacePreTokenizer()]));

-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6]);

     // Change pre tokenizer
     tokenizer.setPreTokenizer(
       sequencePreTokenizer([whitespacePreTokenizer(), punctuationPreTokenizer()])
     );

-    encoding = await encode("my name is john.!?", null);
+    encoding = await encode(input, null);
     expect(encoding.getIds()).toEqual([0, 1, 2, 3, 4, 6, 6, 6]);
   });
 });
bindings/node/native/Cargo.lock (generated)

@@ -448,6 +448,7 @@ dependencies = [
  "neon-runtime",
  "neon-serde",
  "serde",
+ "serde_json",
  "tokenizers",
 ]
bindings/node/native/Cargo.toml

@@ -19,3 +19,4 @@ neon-runtime = "0.3"
 neon-serde = "0.3"
 serde = { version = "1.0", features = [ "rc", "derive" ] }
 tokenizers = { path = "../../../tokenizers" }
+serde_json = "1.0"
bindings/node/native/src/normalizers.rs

@@ -8,7 +8,7 @@ use std::sync::Arc;
 use tk::normalizers::NormalizerWrapper;
 use tk::NormalizedString;

-#[derive(Clone, Deserialize)]
+#[derive(Clone, Debug, Deserialize)]
 pub enum JsNormalizerWrapper {
     Sequence(Vec<Arc<NormalizerWrapper>>),
     Wrapped(Arc<NormalizerWrapper>),

@@ -213,3 +213,47 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Strip", prefix), strip)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::normalizers::unicode::{NFC, NFKC};
+    use tk::normalizers::utils::Sequence;
+    use tk::normalizers::NormalizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsNormalizerWrapper = NFKC.into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+
+        let rs_wrapped = NormalizerWrapper::NFKC(NFKC);
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        // let js_norm: Normalizer = serde_json::from_str(&rs_ser).unwrap();
+        // match js_norm.normalizer.unwrap() {
+        //     JsNormalizerWrapper::Wrapped(nfc) => match nfc.as_ref() {
+        //         NormalizerWrapper::NFKC(_) => {}
+        //         _ => panic!("Expected NFKC"),
+        //     },
+        //     _ => panic!("Expected wrapped, not sequence."),
+        // }
+
+        let js_seq: JsNormalizerWrapper = Sequence::new(vec![NFC.into(), NFKC.into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped =
+            NormalizerWrapper::Sequence(Sequence::new(vec![NFC.into(), NFKC.into()]).into());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = Normalizer {
+            normalizer: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![NFC.into(), NFKC.into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}
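For reference, the strings compared in this test look roughly as follows, assuming the internally tagged ("type") representation tokenizers uses for its normalizers; this is our illustration, not part of the commit:

    // Rough shape of the JSON compared above (assumed; illustration only):
    //   NFKC                  -> {"type":"NFKC"}
    //   Sequence([NFC, NFKC]) -> {"type":"Sequence","normalizers":[{"type":"NFC"},{"type":"NFKC"}]}
    fn nfkc_shape() {
        let ser = serde_json::to_string(&tk::normalizers::unicode::NFKC).unwrap();
        assert_eq!(ser, r#"{"type":"NFKC"}"#);
    }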
bindings/node/native/src/pre_tokenizers.rs

@@ -4,22 +4,63 @@ use crate::extraction::*;
 use neon::prelude::*;
 use std::sync::Arc;

+use serde::{ser::SerializeStruct, Serialize, Serializer};
 use tk::pre_tokenizers::PreTokenizerWrapper;
 use tk::PreTokenizedString;

+#[derive(Clone, Debug, Deserialize)]
+pub enum JsPreTokenizerWrapper {
+    Sequence(Vec<Arc<PreTokenizerWrapper>>),
+    Wrapped(Arc<PreTokenizerWrapper>),
+}
+
+impl Serialize for JsPreTokenizerWrapper {
+    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
+    where
+        S: Serializer,
+    {
+        match self {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                let mut ser = serializer.serialize_struct("Sequence", 2)?;
+                ser.serialize_field("type", "Sequence")?;
+                ser.serialize_field("pretokenizers", seq)?;
+                ser.end()
+            }
+            JsPreTokenizerWrapper::Wrapped(inner) => inner.serialize(serializer),
+        }
+    }
+}
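The manual Serialize impl exists because the Sequence variant has to emit the exact shape the Rust library's Sequence pre-tokenizer produces, while Wrapped must stay transparent to the inner wrapper's JSON. Roughly (our illustration, assuming the tagged unit-struct serialization tokenizers uses):

    use std::sync::Arc;
    use tk::pre_tokenizers::whitespace::WhitespaceSplit;

    fn shapes() {
        // Wrapped is transparent, e.g. {"type":"WhitespaceSplit"}
        let wrapped: JsPreTokenizerWrapper = WhitespaceSplit.into();
        println!("{}", serde_json::to_string(&wrapped).unwrap());

        // Sequence hand-builds the tagged struct, e.g.
        // {"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"}]}
        let seq = JsPreTokenizerWrapper::Sequence(vec![Arc::new(WhitespaceSplit.into())]);
        println!("{}", serde_json::to_string(&seq).unwrap());
    }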
+
+impl<I> From<I> for JsPreTokenizerWrapper
+where
+    I: Into<PreTokenizerWrapper>,
+{
+    fn from(norm: I) -> Self {
+        JsPreTokenizerWrapper::Wrapped(Arc::new(norm.into()))
+    }
+}
+
 /// PreTokenizers
 #[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct PreTokenizer {
     #[serde(flatten)]
-    pub pretok: Option<Arc<PreTokenizerWrapper>>,
+    pub pretok: Option<JsPreTokenizerWrapper>,
 }
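A note on #[serde(flatten)]: it merges the wrapper's keys into the PreTokenizer object itself, so PreTokenizer { pretok: Some(w) } serializes to exactly the same JSON as w, which is what the assert_eq!(js_wrapper_ser, js_ser) in the test below relies on. A self-contained sketch of that serde behavior, with hypothetical Inner/Outer types of our own:

    use serde::Serialize;

    #[derive(Serialize)]
    struct Inner { r#type: &'static str }

    #[derive(Serialize)]
    struct Outer {
        #[serde(flatten)]
        inner: Option<Inner>,
    }

    fn main() {
        let o = Outer { inner: Some(Inner { r#type: "Whitespace" }) };
        // Flatten lifts the inner keys to the top level:
        assert_eq!(serde_json::to_string(&o).unwrap(), r#"{"type":"Whitespace"}"#);
    }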

 impl tk::PreTokenizer for PreTokenizer {
     fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> tk::Result<()> {
-        self.pretok
-            .as_ref()
-            .ok_or("Uninitialized PreTokenizer")?
-            .pre_tokenize(pretokenized)
+        match self.pretok.as_ref().ok_or("Uninitialized PreTokenizer")? {
+            JsPreTokenizerWrapper::Sequence(seq) => {
+                for pretokenizer in seq {
+                    pretokenizer.pre_tokenize(pretokenized)?;
+                }
+            }
+            JsPreTokenizerWrapper::Wrapped(pretokenizer) => {
+                pretokenizer.pre_tokenize(pretokenized)?
+            }
+        };
+
+        Ok(())
     }
 }
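With this change, pre_tokenize either walks the members of a Sequence in order or delegates to the single wrapped pre-tokenizer. A minimal usage sketch (our own, not from the commit), driving the JS-side PreTokenizer over a string the way the Tokenizer does internally:

    use tk::PreTokenizedString;
    use tk::PreTokenizer as PreTokenizerTrait;

    fn run(pretok: &PreTokenizer, text: &str) -> tk::Result<()> {
        // Build the working structure and let the pre-tokenizer split it.
        let mut pretokenized = PreTokenizedString::from(text);
        pretok.pre_tokenize(&mut pretokenized)
    }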
@@ -41,7 +82,7 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {

     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(byte_level.into()));
+    pretok.borrow_mut(&guard).pretok = Some(byte_level.into());
     Ok(pretok)
 }

@@ -59,9 +100,8 @@ fn byte_level_alphabet(mut cx: FunctionContext) -> JsResult<JsValue> {
 fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::Whitespace::default().into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::whitespace::Whitespace::default().into());
     Ok(pretok)
 }

@@ -69,9 +109,7 @@ fn whitespace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::whitespace::WhitespaceSplit.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::whitespace::WhitespaceSplit.into());
     Ok(pretok)
 }

@@ -79,8 +117,7 @@ fn whitespace_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn bert_pre_tokenizer(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok =
-        Some(Arc::new(tk::pre_tokenizers::bert::BertPreTokenizer.into()));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::bert::BertPreTokenizer.into());
     Ok(pretok)
 }

@@ -91,9 +128,8 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {

     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::metaspace::Metaspace::new(replacement, add_prefix_space).into());
     Ok(pretok)
 }

@@ -101,9 +137,7 @@ fn metaspace(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
 fn punctuation(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::punctuation::Punctuation.into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(tk::pre_tokenizers::punctuation::Punctuation.into());
     Ok(pretok)
 }
@@ -118,13 +152,15 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         |pretokenizer| match pretokenizer.downcast::<JsPreTokenizer>().or_throw(&mut cx) {
             Ok(pretokenizer) => {
                 let guard = cx.lock();
-                let pretokenizer = (*pretokenizer.borrow(&guard)).pretok.clone();
-                if let Some(pretokenizer) = pretokenizer {
-                    let pretok = (*pretokenizer).clone();
-                    sequence.push(pretok);
+                let pretok = (*pretokenizer.borrow(&guard)).pretok.clone();
+                if let Some(pretokenizer) = pretok {
+                    match pretokenizer {
+                        JsPreTokenizerWrapper::Sequence(seq) => sequence.extend(seq),
+                        JsPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner),
+                    }
                     Ok(())
                 } else {
-                    cx.throw_error("Uninitialized Normalizer")
+                    cx.throw_error("Uninitialized PreTokenizer")
                 }
             }
             Err(e) => Err(e),

@@ -133,9 +169,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
         .collect::<NeonResult<_>>()?;
     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::sequence::Sequence::new(sequence).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok = Some(JsPreTokenizerWrapper::Sequence(sequence).into());
     Ok(pretok)
 }
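Note the extend-vs-push distinction above: passing an existing Sequence into sequence() splices its members into the new list instead of nesting it, which keeps the serialized form a single flat "pretokenizers" array. An illustration of that behavior (our own, not from the commit):

    fn flatten_demo() {
        use std::sync::Arc;
        use tk::pre_tokenizers::punctuation::Punctuation;
        use tk::pre_tokenizers::whitespace::WhitespaceSplit;
        use tk::pre_tokenizers::PreTokenizerWrapper;

        let mut sequence: Vec<Arc<PreTokenizerWrapper>> = Vec::new();
        let nested = JsPreTokenizerWrapper::Sequence(vec![
            Arc::new(WhitespaceSplit.into()),
            Arc::new(Punctuation.into()),
        ]);
        match nested {
            JsPreTokenizerWrapper::Sequence(seq) => sequence.extend(seq),
            JsPreTokenizerWrapper::Wrapped(inner) => sequence.push(inner),
        }
        assert_eq!(sequence.len(), 2); // members spliced in, not nested
    }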
@@ -145,9 +179,8 @@ fn char_delimiter_split(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {

     let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
     let guard = cx.lock();
-    pretok.borrow_mut(&guard).pretok = Some(Arc::new(
-        tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into(),
-    ));
+    pretok.borrow_mut(&guard).pretok =
+        Some(tk::pre_tokenizers::delimiter::CharDelimiterSplit::new(delimiter).into());

     Ok(pretok)
 }
@@ -171,3 +204,40 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Sequence", prefix), sequence)?;
     Ok(())
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tk::pre_tokenizers::sequence::Sequence;
+    use tk::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
+    use tk::pre_tokenizers::PreTokenizerWrapper;
+
+    #[test]
+    fn serialize() {
+        let js_wrapped: JsPreTokenizerWrapper = Whitespace::default().into();
+        let js_ser = serde_json::to_string(&js_wrapped).unwrap();
+
+        let rs_wrapped = PreTokenizerWrapper::Whitespace(Whitespace::default());
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_ser, rs_ser);
+
+        let js_seq: JsPreTokenizerWrapper =
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into();
+        let js_wrapper_ser = serde_json::to_string(&js_seq).unwrap();
+        let rs_wrapped = PreTokenizerWrapper::Sequence(
+            Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]).into(),
+        );
+        let rs_ser = serde_json::to_string(&rs_wrapped).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+
+        let js_seq = PreTokenizer {
+            pretok: Some(js_seq),
+        };
+        let js_ser = serde_json::to_string(&js_seq).unwrap();
+        assert_eq!(js_wrapper_ser, js_ser);
+
+        let rs_seq = Sequence::new(vec![WhitespaceSplit.into(), Whitespace::default().into()]);
+        let rs_ser = serde_json::to_string(&rs_seq).unwrap();
+        assert_eq!(js_wrapper_ser, rs_ser);
+    }
+}