Adding 2 new decoders: (#1196)

* Adding 2 new decoders:

- Fuse will simply concatenate all tokens into 1 string
- Strip will remove n chars from the left or right of each token

Sequence(Replace("_", " "), Fuse(), Strip(1, 0)) should be what we want
for the `Metaspace` thing.

- Note: Added a new dependency for better parsing of decoders.
This is needed because untagged enums can match anything; `MustBe`
ensures there is no ambiguity between Fuse and ByteFallback.
Since both are new, the chance of backward incompatibility is low.

* Fixing pickling/unpickling (using default args.).

* Stub.

* Black.

* Fixing node.
This commit is contained in:
Nicolas Patry
2023-03-24 00:50:54 +01:00
committed by GitHub
parent d2c8190a0f
commit e4aea890d5
13 changed files with 311 additions and 7 deletions

View File

@ -35,6 +35,18 @@ export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
*/
export function byteFallbackDecoder(): Decoder;
/**
* Instantiate a new Fuse Decoder which fuses all tokens into one string
* (the tokens are concatenated, e.g. ["Hel", "lo"] decodes to "Hello")
*/
export function fuseDecoder(): Decoder;
/**
* Instantiate a new Strip Decoder, which removes a fixed number of characters
* from the start and/or end of each token
* @param left The number of chars to remove from the left of each token
* @param right The number of chars to remove from the right of each token
*/
export function stripDecoder(left: number, right: number): Decoder;
/**
* Instantiate a new Metaspace
*

View File

@ -5,6 +5,8 @@ module.exports = {
replaceDecoder: native.decoders_Replace,
wordPieceDecoder: native.decoders_WordPiece,
byteFallbackDecoder: native.decoders_ByteFallback,
fuseDecoder: native.decoders_Fuse,
stripDecoder: native.decoders_Strip,
metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC,

View File

@ -2,9 +2,11 @@ import {
bpeDecoder,
byteFallbackDecoder,
ctcDecoder,
fuseDecoder,
metaspaceDecoder,
replaceDecoder,
sequenceDecoder,
stripDecoder,
wordPieceDecoder,
} from "./decoders";
@ -51,6 +53,26 @@ describe("replaceDecoder", () => {
});
});
// fuseDecoder takes no arguments; these tests cover construction and decoding.
describe("fuseDecoder", () => {
  it("can be instantiated without arguments", () => {
    expect(fuseDecoder()).toBeDefined();
  });

  it("can decode arrays of strings", () => {
    // All tokens are joined into a single string.
    expect(fuseDecoder().decode(["Hel", "lo"])).toEqual("Hello");
  });
});
// stripDecoder(left, right) is always called with both counts here.
describe("stripDecoder", () => {
  it("can be instantiated with left and right counts", () => {
    expect(stripDecoder(0, 0)).toBeDefined();
  });

  it("can decode arrays of strings", () => {
    // One char is removed from the left of each token: "Hel" -> "el", "lo" -> "o".
    expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
  });
});
describe("metaspaceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(metaspaceDecoder(undefined)).toBeDefined();

View File

@ -96,6 +96,26 @@ fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}
/// fuse()
///
/// Creates an empty `JsDecoder` and installs a `Fuse` decoder into it.
/// Reads no arguments from the JS call.
fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
    let mut js_decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
    {
        // The mutable borrow of the JS object is confined to this scope so the
        // handle can be returned afterwards.
        let lock = cx.lock();
        let mut inner = js_decoder.borrow_mut(&lock);
        inner.decoder = Some(Arc::new(tk::decoders::fuse::Fuse::new().into()));
    }
    Ok(js_decoder)
}
/// strip(left: number, right: number)
///
/// Creates an empty `JsDecoder` and installs a `Strip` decoder built from the
/// two counts read from JS arguments 0 (`left`) and 1 (`right`). Extraction
/// errors are propagated to the caller.
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
    let left = cx.extract::<usize>(0)?;
    let right = cx.extract::<usize>(1)?;

    let mut js_decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
    let stripper = tk::decoders::strip::Strip::new(left, right);
    {
        // The mutable borrow of the JS object is confined to this scope so the
        // handle can be returned afterwards.
        let lock = cx.lock();
        let mut inner = js_decoder.borrow_mut(&lock);
        inner.decoder = Some(Arc::new(stripper.into()));
    }
    Ok(js_decoder)
}
/// metaspace(replacement: String = "_", add_prefix_space: bool = true)
fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
@ -173,6 +193,8 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_Replace", prefix), replace)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
m.export_function(&format!("{}_Fuse", prefix), fuse)?;
m.export_function(&format!("{}_Strip", prefix), strip)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;