Node - Merge encodings

Pierric Cistac
2020-03-26 13:25:28 -04:00
committed by Anthony MOI
parent 4341c79d85
commit 0408567f23
6 changed files with 107 additions and 6 deletions

View File

@@ -1,3 +1,5 @@
import { RawEncoding } from "./raw-encoding";
/**
* Returns a subpart of a string according to specified indexes, and respecting unicode characters
*
@@ -10,3 +12,13 @@
* @since 0.6.0
*/
export function slice(text: string, start?: number, end?: number): string;
/**
* Merge the list of RawEncoding into one final RawEncoding
* @param encodings The list of encodings to merge
* @param [growingOffsets=false] Whether the offsets should accumulate while merging
*/
export function mergeEncodings(
encodings: RawEncoding[],
growingOffsets?: boolean
): RawEncoding;
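For orientation, here is a minimal hedged sketch of how the new binding is meant to be called. It mirrors the test setup further down in this commit, and the relative import paths ("./models", "./tokenizer", "./utils") are taken from that test file rather than from the published package surface.

import { promisify } from "util";
import { BPE } from "./models";
import { RawEncoding } from "./raw-encoding";
import { Tokenizer } from "./tokenizer";
import { mergeEncodings } from "./utils";

async function mergeTwoSequences(): Promise<RawEncoding> {
  // Build a tiny tokenizer, encode two sequences separately, then merge them.
  const tokenizer = new Tokenizer(BPE.empty());
  tokenizer.addTokens(["my", "name", "is", "john"]);
  const encode = promisify(tokenizer.encode.bind(tokenizer)) as (
    sequence: string,
    pair: string | null,
    addSpecialTokens: boolean
  ) => Promise<RawEncoding>;

  const first = await encode("my name is", null, false);
  const second = await encode("john", null, false);

  // With `growingOffsets` omitted (or false), "john" keeps its own offsets [0, 4];
  // with `growingOffsets` set to true they become [10, 14] (see the tests below).
  return mergeEncodings([first, second]);
}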

View File

@@ -1,5 +1,6 @@
const native = require("./native");
module.exports = {
mergeEncodings: native.utils_mergeEncodings,
slice: native.utils_slice
};

View File

@@ -1,4 +1,9 @@
import { slice } from "./utils";
import { promisify } from "util";
import { BPE } from "./models";
import { RawEncoding } from "./raw-encoding";
import { Tokenizer } from "./tokenizer";
import { mergeEncodings, slice } from "./utils";
describe("slice", () => {
const text = "My name is John 👋";
@@ -105,3 +110,66 @@ describe("slice", () => {
});
});
});
describe("mergeEncodings", () => {
let encode: (
sequence: string,
pair: string | null,
addSpecialTokens: boolean
) => Promise<RawEncoding>;
beforeEach(async () => {
const model = BPE.empty();
const tokenizer = new Tokenizer(model);
tokenizer.addTokens(["my", "name", "is", "john"]);
encode = promisify(tokenizer.encode.bind(tokenizer));
});
it("accepts `undefined` as a second parameter", () => {
const encoding = mergeEncodings([], undefined);
expect(encoding.constructor.name).toEqual("Encoding");
});
it("returns correct result with `growingOffsets` not provided", async () => {
const firstEncoding = await encode("my name is", null, false);
const secondEncoding = await encode("john", null, false);
const encoding = mergeEncodings([firstEncoding, secondEncoding]);
expect(encoding.getTokens()).toEqual(["my", "name", "is", "john"]);
expect(encoding.getOffsets()).toEqual([
[0, 2],
[3, 7],
[8, 10],
[0, 4]
]);
});
it("returns correct result when `growingOffsets` is `false`", async () => {
const firstEncoding = await encode("my name is", null, false);
const secondEncoding = await encode("john", null, false);
const encoding = mergeEncodings([firstEncoding, secondEncoding], false);
expect(encoding.getTokens()).toEqual(["my", "name", "is", "john"]);
expect(encoding.getOffsets()).toEqual([
[0, 2],
[3, 7],
[8, 10],
[0, 4]
]);
});
it("returns correct result when `growingOffsets` is `true`", async () => {
const firstEncoding = await encode("my name is", null, false);
const secondEncoding = await encode("john", null, false);
const encoding = mergeEncodings([firstEncoding, secondEncoding], true);
expect(encoding.getTokens()).toEqual(["my", "name", "is", "john"]);
expect(encoding.getOffsets()).toEqual([
[0, 2],
[3, 7],
[8, 10],
[10, 14]
]);
});
});
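The last two cases differ only in how offsets are combined: with `growingOffsets` left at false each encoding keeps offsets relative to its own sequence, while with true each subsequent encoding is shifted by the largest end offset accumulated so far. A sketch of that arithmetic follows; it is not the library's actual implementation, which lives on the Rust side.

// Hypothetical helper illustrating the `growingOffsets = true` behaviour:
// shift every encoding's offsets by the end offset of the encodings merged before it.
function growOffsets(offsetLists: [number, number][][]): [number, number][] {
  const merged: [number, number][] = [];
  let shift = 0;
  for (const offsets of offsetLists) {
    for (const [start, end] of offsets) {
      merged.push([start + shift, end + shift]);
    }
    if (merged.length > 0) {
      shift = merged[merged.length - 1][1];
    }
  }
  return merged;
}

// growOffsets([[[0, 2], [3, 7], [8, 10]], [[0, 4]]])
// => [[0, 2], [3, 7], [8, 10], [10, 14]], matching the last test above.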

View File

@@ -1,4 +1,5 @@
import { PaddingOptions, RawEncoding } from "../bindings/raw-encoding";
import { mergeEncodings } from "../bindings/utils";
export class Encoding {
private _attentionMask?: number[];
@@ -13,6 +14,20 @@ export class Encoding {
constructor(private rawEncoding: RawEncoding) {}
/**
* Merge a list of Encoding into one final Encoding
* @param encodings The list of encodings to merge
* @param [growingOffsets=false] Whether the offsets should accumulate while merging
*/
static merge(encodings: Encoding[], growingOffsets?: boolean): Encoding {
const mergedRaw = mergeEncodings(
encodings.map(e => e.rawEncoding),
growingOffsets
);
return new Encoding(mergedRaw);
}
/**
* Attention mask
*/

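The static `merge` added above is a thin wrapper over the native `mergeEncodings` binding: it unwraps each `Encoding` to its `RawEncoding`, merges them, and wraps the result again. A short hedged usage sketch, where the import paths are assumptions based on this file's own relative imports into `../bindings`:

import { Encoding } from "./encoding"; // hypothetical path to this class
import { RawEncoding } from "../bindings/raw-encoding";

// Wrap raw encodings (obtained as in the earlier sketch) and merge them
// through the class-level API instead of the raw binding.
function mergeWrapped(first: RawEncoding, second: RawEncoding): Encoding {
  return Encoding.merge([new Encoding(first), new Encoding(second)], false);
}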
View File

@@ -1,6 +1,5 @@
import { promisify } from "util";
import { RawEncoding } from "../../bindings/raw-encoding";
import {
AddedToken,
PaddingConfiguration,

View File

@@ -64,10 +64,16 @@ fn merge_encodings(mut cx: FunctionContext) -> JsResult<JsEncoding> {
})
.collect::<Result<Vec<_>, neon::result::Throw>>()
.map_err(|e| cx.throw_error::<_, ()>(format!("{}", e)).unwrap_err())?;
let growing_offsets = cx
.argument_opt(1)
.map(|arg| Ok(arg.downcast::<JsBoolean>().or_throw(&mut cx)?.value()))
.unwrap_or(Ok(false))?;
let growing_offsets = if let Some(arg) = cx.argument_opt(1) {
if arg.downcast::<JsUndefined>().is_err() {
arg.downcast::<JsBoolean>().or_throw(&mut cx)?.value()
} else {
false
}
} else {
false
};
let new_encoding = tk::tokenizer::Encoding::merge(encodings.as_slice(), growing_offsets);
let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?;
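On the Rust side, the previous one-liner turned any non-boolean second argument into a thrown error, so an explicit `undefined` from JavaScript would fail even though omitting the argument worked. The rewritten block checks for `JsUndefined` first and falls back to the default `false`, which is what the new "accepts `undefined` as a second parameter" test exercises. A sketch of the call sites this makes equivalent, assuming it sits next to the bindings so the relative imports resolve:

import { RawEncoding } from "./raw-encoding";
import { mergeEncodings } from "./utils";

// All three forms now take the default, non-growing offsets path;
// the middle one used to fail the JsBoolean downcast before this change.
function mergeWithDefaults(encodings: RawEncoding[]): RawEncoding {
  mergeEncodings(encodings);                // argument omitted
  mergeEncodings(encodings, undefined);     // explicit undefined
  return mergeEncodings(encodings, false);  // explicit default
}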