Adding ByteFallback support for tokenizers. (#1183)

* Adding ByteFallback support for `tokenizers`.

Two items added:

- A flag `byte_fallback` for the `BPE` model. This will be in charge
  of using `<0x61>` instead of unk on unknown tokens.
- A ByteFallback decoder, which will be in charge of putting everything
  back into a string whenever possible. Showing � when the byte decoding
  fails (behavior checked against LlamaTokenizer in `transformers`).

* Update rustdoc.

* Clippy + Add BPE(byte_fallback) into bindings.

* Stupid file.

* Test artifacts removed.

* Update stub.

* Fix.

* Bad file.

* CRITICAL FIX: wrapper order because of untagged....

* Remove prints.

* Fixing <16 byte fallback.
This commit is contained in:
Nicolas Patry
2023-03-23 16:04:32 +01:00
committed by GitHub
parent b8fbea00a9
commit 73637a0004
16 changed files with 359 additions and 21 deletions

View File

@ -20,6 +20,14 @@ export function byteLevelDecoder(): Decoder;
*/
export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
/**
* Instantiate a new ByteFallback Decoder
* ByteFallback is a simple trick which converts tokens looking like `<0x61>`
* to pure bytes, and attempts to make them into a string. If the tokens
* cannot be decoded you will get � instead for each inconvertible byte token
*/
export function byteFallbackDecoder(): Decoder;
/**
* Instantiate a new Metaspace
*

View File

@ -3,6 +3,7 @@ const native = require("./native");
module.exports = {
byteLevelDecoder: native.decoders_ByteLevel,
wordPieceDecoder: native.decoders_WordPiece,
byteFallbackDecoder: native.decoders_ByteFallback,
metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC,

View File

@ -1,5 +1,6 @@
import {
bpeDecoder,
byteFallbackDecoder,
ctcDecoder,
metaspaceDecoder,
sequenceDecoder,
@ -22,6 +23,27 @@ describe("wordPieceDecoder", () => {
});
});
describe("byteFallbackDecoder", () => {
  it("accepts `undefined` as first parameter", () => {
    expect(byteFallbackDecoder()).toBeDefined();
  });

  it("can decode arrays of strings", () => {
    // Plain tokens are passed through and concatenated unchanged.
    expect(byteFallbackDecoder().decode(["Hel", "lo"])).toEqual("Hello");
    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
    expect(byteFallbackDecoder().decode(["My", " na", "me"])).toEqual("My name");
    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");

    // Byte tokens that do not form valid UTF-8 yield one U+FFFD replacement
    // character per byte token. The escape form is used on purpose so the
    // expectation cannot be corrupted by a file-encoding round trip.
    expect(byteFallbackDecoder().decode(["<0xE5>"])).toEqual("\uFFFD");
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>"])).toEqual("\uFFFD\uFFFD");

    // A complete 3-byte UTF-8 sequence decodes to the real character.
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>"])).toEqual("叫");

    // An incomplete sequence followed by a regular token: the two dangling
    // bytes become replacement characters, the regular token is kept as is.
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "a"])).toEqual(
      "\uFFFD\uFFFDa"
    );
    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>", "a"])).toEqual(
      "叫a"
    );
  });
});
describe("metaspaceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(metaspaceDecoder(undefined)).toBeDefined();

View File

@ -72,6 +72,16 @@ fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}
/// byte_fallback()
///
/// Builds an empty `JsDecoder` and stores a freshly constructed
/// `ByteFallback` decoder (converted with `.into()` and wrapped in an `Arc`)
/// in its `decoder` slot. Takes no arguments from JavaScript.
fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
    let mut js_decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;

    // The lock guard is required to mutably borrow the JS-owned object.
    let guard = cx.lock();
    let wrapped = Arc::new(tk::decoders::byte_fallback::ByteFallback::new().into());
    js_decoder.borrow_mut(&guard).decoder = Some(wrapped);

    Ok(js_decoder)
}
/// metaspace(replacement: String = "_", add_prefix_space: bool = true)
fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
@ -147,6 +157,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;

View File

@ -132,6 +132,7 @@ struct BpeOptions {
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
fuse_unk: Option<bool>,
byte_fallback: Option<bool>,
}
impl BpeOptions {
fn apply_to_bpe_builder(self, mut builder: BpeBuilder) -> BpeBuilder {
@ -153,6 +154,9 @@ impl BpeOptions {
if let Some(fuse_unk) = self.fuse_unk {
builder = builder.fuse_unk(fuse_unk);
}
if let Some(byte_fallback) = self.byte_fallback {
builder = builder.byte_fallback(byte_fallback);
}
builder
}