node: add binding for original string

This commit is contained in:
Pierric Cistac
2020-01-31 18:09:05 -05:00
committed by Anthony MOI
parent e9ecd5aeec
commit a14c63343b
3 changed files with 177 additions and 0 deletions

View File

@ -0,0 +1,124 @@
import { promisify } from "util";
import { BPE } from "./models";
import { Encoding, Tokenizer } from "./tokenizer";
describe("Encoding", () => {
const originalString = "my name is john";
let encoding: Encoding;
beforeEach(async () => {
const model = BPE.empty();
const tokenizer = new Tokenizer(model);
tokenizer.addTokens(["my", "name", "is", "john", "pair"]);
const encode = promisify(tokenizer.encode.bind(tokenizer));
encoding = await encode(originalString, null);
});
it("has a list of defined methods", async () => {
expect(typeof encoding.getAttentionMask).toBe("function");
expect(typeof encoding.getIds).toBe("function");
expect(typeof encoding.getOffsets).toBe("function");
expect(typeof encoding.getOriginalString).toBe("function");
expect(typeof encoding.getOverflowing).toBe("function");
expect(typeof encoding.getSpecialTokensMask).toBe("function");
expect(typeof encoding.getTokens).toBe("function");
expect(typeof encoding.getTypeIds).toBe("function");
expect(typeof encoding.pad).toBe("function");
expect(typeof encoding.truncate).toBe("function");
});
describe("getOriginalString", () => {
it("returns the full original string when no params", () => {
const original = encoding.getOriginalString();
expect(original).toEqual(originalString);
});
it("throws an error when `begin` is out of range", () => {
expect(() => encoding.getOriginalString(1000)).toThrow();
});
it("returns the original string starting at the specified index", () => {
const original = encoding.getOriginalString(3);
expect(original).toEqual("name is john");
});
it("throws an error when `end` is out of range", () => {
expect(() => encoding.getOriginalString(0, 1000)).toThrow();
});
it("returns the original string between the two specified indexes", () => {
const original = encoding.getOriginalString(3, 7);
expect(original).toEqual("name");
});
describe("with only a negative `begin`", () => {
it("returns the original string counting from the end when in the range", () => {
const original = encoding.getOriginalString(-4);
expect(original).toEqual("john");
});
it("throws an error when out of range", () => {
expect(() => encoding.getOriginalString(-1000)).toThrow();
});
});
describe("with a positive `begin` and a negative `end`", () => {
it("returns the original string when resulting range is valid", () => {
const original = encoding.getOriginalString(3, -5);
expect(original).toEqual("name is");
});
it("throws an error when resulting `end` index is lower than `begin`", () => {
expect(() => encoding.getOriginalString(7, -10)).toThrow();
});
it("throws an error when `begin` is out of range", () => {
expect(() => encoding.getOriginalString(1000, -10)).toThrow();
});
it("throws an error when resulting `end` index is out of range", () => {
expect(() => encoding.getOriginalString(7, -1000)).toThrow();
});
});
describe("with a negative `begin` and a positive `end`", () => {
it("returns the original string when resulting range is valid", () => {
const original = encoding.getOriginalString(-7, 10);
expect(original).toEqual("is");
});
it("throws an error when resulting `begin` index is upper than `end`", () => {
expect(() => encoding.getOriginalString(-3, 5)).toThrow();
});
it("throws an error when `end` is out of range", () => {
expect(() => encoding.getOriginalString(-5, 1000)).toThrow();
});
it("throws an error when resulting `begin` index is out of range", () => {
expect(() => encoding.getOriginalString(-1000, 10)).toThrow();
});
});
describe("with negatives `begin` and `end`", () => {
it("returns the original string when resulting range is valid", () => {
const original = encoding.getOriginalString(-7, -5);
expect(original).toEqual("is");
});
it("throws an error when resulting `end` index is lower than `begin`", () => {
expect(() => encoding.getOriginalString(-5, -10)).toThrow();
});
it("throws an error when resulting `begin` index is out of range", () => {
expect(() => encoding.getOriginalString(-1000, -10)).toThrow();
});
it("throws an error when resulting `end` index is out of range", () => {
expect(() => encoding.getOriginalString(-10, -1000)).toThrow();
});
});
});
});

View File

@ -227,6 +227,17 @@ interface Encoding {
*/
getTypeIds(): number[];
/**
* Returns the original string
*
* @param [begin] The index from which to start (can be negative).
* @param [end] The index (excluded) to which to stop (can be negative).
* Stopping at the end of the string if not provided.
* @returns The full original string if no parameter is provided,
* otherwise the original string between `begin` and `end`
*/
getOriginalString(begin?: number, end?: number): string;
/**
* Pad the current Encoding at the given length
*

View File

@ -148,6 +148,48 @@ declare_types! {
}
}
method getOriginalString(mut cx) {
// getOriginalString(begin?: number, end?: number)
let this = cx.this();
let guard = cx.lock();
let normalized = this.borrow(&guard).encoding.execute(|encoding| {
encoding.unwrap().get_normalized().clone()
});
let get_index = |x: i32| -> usize {
if x >= 0 {
x as usize
} else {
(normalized.len_original() as i32 + x) as usize
}
};
if let Some(begin_arg) = cx.argument_opt(0) {
let begin = begin_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
let begin_index = get_index(begin);
let end_index = if let Some(end_arg) = cx.argument_opt(1) {
let end = end_arg.downcast::<JsNumber>().or_throw(&mut cx)?.value() as i32;
get_index(end)
} else {
normalized.len_original()
};
let original = normalized.get_range_original(begin_index..end_index);
if let Some(original) = original {
Ok(cx.string(original).upcast())
} else {
cx.throw_error("Error in offsets")
}
} else {
let original = normalized.get_original();
Ok(cx.string(original.to_owned()).upcast())
}
}
method pad(mut cx) {
// pad(length: number, options?: {
// direction?: 'left' | 'right' = 'right',