Fixing Node.js example.

- Now supports a more lenient syntax, more closely aligned with Python & Rust.
- Backward compatible.
This commit is contained in:
Nicolas Patry
2020-09-29 11:23:23 +02:00
committed by Anthony MOI
parent 6f8892e3ae
commit 44e8f4be8f
6 changed files with 2832 additions and 1517 deletions
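
In practice, the relaxed argument handling means the trailing arguments of encode and decode are now optional. A minimal sketch of the newly accepted call shapes — not part of the commit itself; it assumes the promisified wrappers from the updated example below, and the isPretokenized option name as it appears in the diff's error messages:

    const { promisify } = require("util");
    const tokenizers = require("tokenizers");

    const tokenizer = tokenizers.Tokenizer.fromFile("data/roberta.json");
    const encode = promisify(tokenizer.encode.bind(tokenizer));
    const decode = promisify(tokenizer.decode.bind(tokenizer));

    async function main() {
      await encode("my name is john");                        // single sequence
      await encode("my name is john", "pair");                // sequence pair
      await encode(["my", "name"], { isPretokenized: true }); // options without a pair
      await decode([0, 1, 2, 3]);                             // skipSpecialTokens defaults to false
    }
    main();

Previously, the trailing arguments were mandatory (see the removed "not enough arguments" test below).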

View File

@@ -1,18 +1,25 @@
 /*eslint-disable no-undef*/
 const tokenizers = require("tokenizers");
+const { promisify } = require("util");
 
 describe("loadExample", () => {
-  it("", () => {
+  beforeAll(async () => {});
+  it("", async () => {
     const example = "This is an example";
     const ids = [713, 16, 41, 1246];
     const tokens = ["This", "Ġis", "Ġan", "Ġexample"];
 
     const tokenizer = tokenizers.Tokenizer.fromFile("data/roberta.json");
 
-    const encoded = tokenizer.encode(example);
-    expect(encoded.ids).toBe(ids);
-    expect(encoded.tokens).toBe(tokens);
-
-    expect(tokenizer.decode(ids)).toBe(example);
+    // You could also use regular callbacks
+    const encode = promisify(tokenizer.encode.bind(tokenizer));
+    const decode = promisify(tokenizer.decode.bind(tokenizer));
+
+    const encoded = await encode(example);
+    expect(encoded.getIds()).toEqual(ids);
+    expect(encoded.getTokens()).toEqual(tokens);
+
+    const decoded = await decode(ids);
+    expect(decoded).toEqual(example);
   });
 });
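
As the "You could also use regular callbacks" comment suggests, the native methods follow the Node (err, result) callback convention, which is exactly why util.promisify applies to them directly. A hedged callback-style sketch of the same flow, reusing the values from the test above:

    const tokenizers = require("tokenizers");

    const tokenizer = tokenizers.Tokenizer.fromFile("data/roberta.json");
    tokenizer.encode("This is an example", (err, encoding) => {
      if (err) throw err;
      console.log(encoding.getIds());    // [713, 16, 41, 1246]
      console.log(encoding.getTokens()); // ["This", "Ġis", "Ġan", "Ġexample"]
    });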

View File

@@ -138,7 +138,7 @@ describe("Tokenizer", () => {
     });
 
     it("accepts a pair of strings as parameters", async () => {
-      const encoding = await encode("my name is john", "pair", undefined);
+      const encoding = await encode("my name is john", "pair");
       expect(encoding).toBeDefined();
     });
@@ -174,10 +174,9 @@ describe("Tokenizer", () => {
       expect(encoding).toBeDefined();
     });
 
-    it("throws if called with only one argument", async () => {
-      await expect((encode as any)("my name is john")).rejects.toThrow(
-        "not enough arguments"
-      );
+    it("Encodes correctly if called with only one argument", async () => {
+      const encoded = await encode("my name is john");
+      expect(encoded.getIds()).toEqual([0, 1, 2, 3]);
     });
 
     it("returns an Encoding", async () => {

View File

@@ -66,7 +66,6 @@ impl Task for EncodeTask<'static> {
                 // Set the actual encoding
                 let guard = cx.lock();
                 js_encoding.borrow_mut(&guard).encoding = Some(encoding);
-
                 Ok(js_encoding.upcast())
             }
             EncodeOutput::Batch(encodings) => {

View File

@@ -449,23 +449,19 @@ declare_types! {
             // __callback: (err, encoding) -> void
             // )
 
-            // Start by extracting options and callback
-            let (options, callback) = match cx.extract_opt::<EncodeOptions>(2) {
-                // Options were there, and extracted
-                Ok(Some(options)) => {
-                    (options, cx.argument::<JsFunction>(3)?)
-                },
-                // Options were undefined or null
-                Ok(None) => {
-                    (EncodeOptions::default(), cx.argument::<JsFunction>(3)?)
-                }
-                // Options not specified, callback instead
-                Err(_) => {
-                    (EncodeOptions::default(), cx.argument::<JsFunction>(2)?)
-                }
-            };
+            // Start by extracting options if they exist (options is in slot 1 or 2)
+            let mut i = 1;
+            let (options, option_index) = loop {
+                if let Ok(Some(opts)) = cx.extract_opt::<EncodeOptions>(i){
+                    break (opts, Some(i));
+                }
+                i += 1;
+                if i == 3{
+                    break (EncodeOptions::default(), None)
+                }
+            };
 
-            // Then we extract our input sequences
+            // Then we extract the first input sentence
             let sentence: tk::InputSequence = if options.is_pretokenized {
                 cx.extract::<PreTokenizedInputSequence>(0)
                     .map_err(|_| Error("encode with isPretokenized=true expect string[]".into()))?
@@ -475,15 +471,29 @@ declare_types! {
                     .map_err(|_| Error("encode with isPreTokenized=false expect string".into()))?
                     .into()
             };
-            let pair: Option<tk::InputSequence> = if options.is_pretokenized {
-                cx.extract_opt::<PreTokenizedInputSequence>(1)
-                    .map_err(|_| Error("encode with isPretokenized=true expect string[]".into()))?
-                    .map(|v| v.into())
-            }else{
-                cx.extract_opt::<TextInputSequence>(1)
-                    .map_err(|_| Error("encode with isPreTokenized=false expect string".into()))?
-                    .map(|v| v.into())
-            };
+
+            let (pair, has_pair_arg): (Option<tk::InputSequence>, bool) = if options.is_pretokenized {
+                if let Ok(second) = cx.extract_opt::<PreTokenizedInputSequence>(1){
+                    (second.map(|v| v.into()), true)
+                }else{
+                    (None, false)
+                }
+            } else if let Ok(second) = cx.extract_opt::<TextInputSequence>(1){
+                (second.map(|v| v.into()), true)
+            }else{
+                (None, false)
+            };
+
+            // Find the callback index.
+            let last_index = if let Some(option_index) = option_index{
+                option_index + 1
+            }else if has_pair_arg{
+                2
+            }else{
+                1
+            };
+            let callback = cx.argument::<JsFunction>(last_index)?;
+
             let input: tk::EncodeInput = match pair {
                 Some(pair) => (sentence, pair).into(),
                 None => sentence.into()
@@ -557,8 +567,12 @@ declare_types! {
             // decode(ids: number[], skipSpecialTokens: bool, callback)
             let ids = cx.extract_vec::<u32>(0)?;
-            let skip_special_tokens = cx.extract::<bool>(1)?;
-            let callback = cx.argument::<JsFunction>(2)?;
+            let (skip_special_tokens, callback_index) = if let Ok(skip_special_tokens) = cx.extract::<bool>(1){
+                (skip_special_tokens, 2)
+            }else{
+                (false, 1)
+            };
+            let callback = cx.argument::<JsFunction>(callback_index)?;
 
             let this = cx.this();
             let guard = cx.lock();
@@ -575,8 +589,12 @@ declare_types! {
             // decodeBatch(sequences: number[][], skipSpecialTokens: bool, callback)
             let sentences = cx.extract_vec::<Vec<u32>>(0)?;
-            let skip_special_tokens = cx.extract::<bool>(1)?;
-            let callback = cx.argument::<JsFunction>(2)?;
+            let (skip_special_tokens, callback_index) = if let Ok(skip_special_tokens) = cx.extract::<bool>(1){
+                (skip_special_tokens, 2)
+            }else{
+                (false, 1)
+            };
+            let callback = cx.argument::<JsFunction>(callback_index)?;
 
             let this = cx.this();
             let guard = cx.lock();
@@ -956,7 +974,8 @@ pub fn tokenizer_from_string(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
         Decoder,
     > = s.parse().map_err(|e| Error(format!("{}", e)))?;
 
-    let mut js_tokenizer = JsTokenizer::new::<_, JsTokenizer, _>(&mut cx, vec![])?;
+    let js_model: Handle<JsModel> = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
+    let mut js_tokenizer = JsTokenizer::new(&mut cx, vec![js_model])?;
     let guard = cx.lock();
     js_tokenizer.borrow_mut(&guard).tokenizer = Arc::new(RwLock::new(tokenizer));
@@ -969,7 +988,8 @@ pub fn tokenizer_from_file(mut cx: FunctionContext) -> JsResult<JsTokenizer> {
     let tokenizer = tk::tokenizer::TokenizerImpl::from_file(s)
        .map_err(|e| Error(format!("Error loading from file{}", e)))?;
 
-    let mut js_tokenizer = JsTokenizer::new::<_, JsTokenizer, _>(&mut cx, vec![])?;
+    let js_model: Handle<JsModel> = JsModel::new::<_, JsModel, _>(&mut cx, vec![])?;
+    let mut js_tokenizer = JsTokenizer::new(&mut cx, vec![js_model])?;
     let guard = cx.lock();
     js_tokenizer.borrow_mut(&guard).tokenizer = Arc::new(RwLock::new(tokenizer));
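
The callback-index bookkeeping above is needed because Node-style callbacks always come last, so the callback's slot shifts with however many optional arguments the caller supplies (util.promisify likewise appends the callback after whatever arguments were given). A sketch of the argument layouts the new encode extraction resolves, with the slot the callback lands in; the tokenizer setup is assumed as in the example test, and isPretokenized is spelled as in the diff's error messages:

    // callback in slot 1: no pair, no options
    tokenizer.encode("my name is john", (err, enc) => {});

    // callback in slot 2: pair supplied
    tokenizer.encode("my name is john", "pair", (err, enc) => {});

    // callback in slot 2: options found in slot 1, so option_index + 1 = 2
    tokenizer.encode(["my", "name"], { isPretokenized: true }, (err, enc) => {});

    // callback in slot 3: pair and options supplied, option_index + 1 = 3
    tokenizer.encode("my name is john", "pair", { isPretokenized: false }, (err, enc) => {});

decode and decodeBatch follow the same pattern with a single optional skipSpecialTokens flag, which defaults to false when the callback is found in slot 1.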

File diff suppressed because it is too large.

View File

@@ -15,24 +15,24 @@
   "author": "Anthony MOI <m.anthony.moi@gmail.com>",
   "license": "Apache-2.0",
   "dependencies": {
-    "@types/node": "^13.13.21",
+    "@types/node": "^13.1.6",
     "node-pre-gyp": "^0.14.0"
   },
   "devDependencies": {
-    "@types/jest": "^26.0.14",
-    "@typescript-eslint/eslint-plugin": "^3.10.1",
-    "@typescript-eslint/parser": "^3.10.1",
-    "eslint": "^7.10.0",
-    "eslint-config-prettier": "^6.12.0",
-    "eslint-plugin-jest": "^23.20.0",
-    "eslint-plugin-jsdoc": "^30.6.1",
+    "@types/jest": "^26.0.7",
+    "@typescript-eslint/eslint-plugin": "^3.7.0",
+    "@typescript-eslint/parser": "^3.7.0",
+    "eslint": "^7.5.0",
+    "eslint-config-prettier": "^6.11.0",
+    "eslint-plugin-jest": "^23.18.0",
+    "eslint-plugin-jsdoc": "^30.0.3",
     "eslint-plugin-prettier": "^3.1.4",
     "eslint-plugin-simple-import-sort": "^5.0.3",
-    "jest": "^26.4.2",
-    "neon-cli": "^0.4.2",
-    "prettier": "^2.1.2",
+    "jest": "^26.1.0",
+    "neon-cli": "^0.3.3",
+    "prettier": "^2.0.5",
     "shelljs": "^0.8.3",
-    "ts-jest": "^26.4.0",
+    "ts-jest": "^26.1.3",
     "typescript": "^3.9.7"
   },
   "engines": {