mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Adding ByteFallback support for tokenizers
. (#1183)
* Adding ByteFallback support for `tokenizers`. Two items added: - A flag `byte_fallback` for the `BPE` model. This will be in charge of using `<0x61>` instead of unk on unknown tokens. - A ByteFallback decoder, which will be in charge of putting everything back into string whenever possible. Showing � when the byte decoding fails (behavior checked against LlamaTokenizer in `transformers`. * Update rustdoc. * Clippy + Add BPE(byte_fallback) into bindings. * Stupid file. * Test artifacts removed. * Update stub. * Fix. * Bad file. * CRITICAL FIX: wrapper order because of untagged.... * Remove prints. * Fixing <16 byte fallback.
This commit is contained in:
8
bindings/node/lib/bindings/decoders.d.ts
vendored
8
bindings/node/lib/bindings/decoders.d.ts
vendored
@ -20,6 +20,14 @@ export function byteLevelDecoder(): Decoder;
|
||||
*/
|
||||
export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
|
||||
|
||||
/**
|
||||
* Instantiate a new ByteFallback Decoder
|
||||
* ByteFallback is a simple trick which converts tokens looking like `<0x61>`
|
||||
* to pure bytes, and attempts to make them into a string. If the tokens
|
||||
* cannot be decoded you will get <20> instead for each inconvertable byte token
|
||||
*/
|
||||
export function byteFallbackDecoder(): Decoder;
|
||||
|
||||
/**
|
||||
* Instantiate a new Metaspace
|
||||
*
|
||||
|
@ -3,6 +3,7 @@ const native = require("./native");
|
||||
module.exports = {
|
||||
byteLevelDecoder: native.decoders_ByteLevel,
|
||||
wordPieceDecoder: native.decoders_WordPiece,
|
||||
byteFallbackDecoder: native.decoders_ByteFallback,
|
||||
metaspaceDecoder: native.decoders_Metaspace,
|
||||
bpeDecoder: native.decoders_BPEDecoder,
|
||||
ctcDecoder: native.decoders_CTC,
|
||||
|
@ -1,5 +1,6 @@
|
||||
import {
|
||||
bpeDecoder,
|
||||
byteFallbackDecoder,
|
||||
ctcDecoder,
|
||||
metaspaceDecoder,
|
||||
sequenceDecoder,
|
||||
@ -22,6 +23,27 @@ describe("wordPieceDecoder", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("byteFallbackDecoder", () => {
|
||||
it("accepts `undefined` as first parameter", () => {
|
||||
expect(byteFallbackDecoder()).toBeDefined();
|
||||
});
|
||||
|
||||
it("can decode arrays of strings", () => {
|
||||
expect(byteFallbackDecoder().decode(["Hel", "lo"])).toEqual("Hello");
|
||||
expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
|
||||
expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
|
||||
expect(byteFallbackDecoder().decode(["My", " na", "me"])).toEqual("My name");
|
||||
expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
|
||||
expect(byteFallbackDecoder().decode(["<0xE5>"])).toEqual("<22>");
|
||||
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>"])).toEqual("<22><>");
|
||||
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>"])).toEqual("叫");
|
||||
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "a"])).toEqual("<22><>a");
|
||||
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>", "a"])).toEqual(
|
||||
"叫a"
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("metaspaceDecoder", () => {
|
||||
it("accepts `undefined` as first parameter", () => {
|
||||
expect(metaspaceDecoder(undefined)).toBeDefined();
|
||||
|
@ -72,6 +72,16 @@ fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
/// byte_fallback()
|
||||
fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
|
||||
let guard = cx.lock();
|
||||
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
|
||||
tk::decoders::byte_fallback::ByteFallback::new().into(),
|
||||
));
|
||||
Ok(decoder)
|
||||
}
|
||||
|
||||
/// metaspace(replacement: String = "_", add_prefix_space: bool = true)
|
||||
fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||
let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
|
||||
@ -147,6 +157,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
|
||||
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
||||
m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
|
||||
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
|
||||
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
|
||||
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
|
||||
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
|
||||
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
|
||||
|
@ -132,6 +132,7 @@ struct BpeOptions {
|
||||
continuing_subword_prefix: Option<String>,
|
||||
end_of_word_suffix: Option<String>,
|
||||
fuse_unk: Option<bool>,
|
||||
byte_fallback: Option<bool>,
|
||||
}
|
||||
impl BpeOptions {
|
||||
fn apply_to_bpe_builder(self, mut builder: BpeBuilder) -> BpeBuilder {
|
||||
@ -153,6 +154,9 @@ impl BpeOptions {
|
||||
if let Some(fuse_unk) = self.fuse_unk {
|
||||
builder = builder.fuse_unk(fuse_unk);
|
||||
}
|
||||
if let Some(byte_fallback) = self.byte_fallback {
|
||||
builder = builder.byte_fallback(byte_fallback);
|
||||
}
|
||||
|
||||
builder
|
||||
}
|
||||
|
Reference in New Issue
Block a user