Adding ByteFallback support for `tokenizers`. (#1183)

* Adding ByteFallback support for `tokenizers`. Two items added:
  - A flag `byte_fallback` for the `BPE` model. This will be in charge of
    using `<0x61>`-style byte tokens instead of `unk` on unknown tokens.
  - A `ByteFallback` decoder, which will be in charge of putting everything
    back into a string whenever possible, showing `�` when the byte decoding
    fails (behavior checked against `LlamaTokenizer` in `transformers`).
* Update rustdoc.
* Clippy + add BPE(byte_fallback) into the bindings.
* Housekeeping: test artifacts removed, stub updated, stray files and debug prints cleaned up.
* CRITICAL FIX: wrapper order, because of the untagged enum.
* Fixing the <16 byte fallback.
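For orientation, here is a minimal sketch (an editor's addition, not part of the commit) of the intended round trip, assuming the crate-root re-exports of the `Model` and `Decoder` traits and the `BpeBuilder` API shown in the diffs below:

    use std::collections::HashMap;

    use tokenizers::decoders::byte_fallback::ByteFallback;
    use tokenizers::models::bpe::BpeBuilder;
    use tokenizers::{Decoder, Model};

    fn main() -> tokenizers::Result<()> {
        // A toy vocab with no merges: only <unk> and the byte token for 'a'.
        let vocab: HashMap<String, u32> =
            HashMap::from([("<unk>".to_string(), 0), ("<0x61>".to_string(), 1)]);
        let bpe = BpeBuilder::default()
            .vocab_and_merges(vocab, vec![])
            .unk_token("<unk>".to_string())
            .byte_fallback(true) // the new flag: emit <0xNN> tokens instead of <unk>
            .build()?;

        // "a" is not in the vocab as a plain token, so it becomes its byte token.
        let tokens = bpe.tokenize("a")?;
        assert_eq!(tokens[0].value, "<0x61>");

        // The new decoder reverses the trick at decode time.
        assert_eq!(ByteFallback::new().decode(vec!["<0x61>".into()])?, "a");
        Ok(())
    }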
bindings/node/lib/bindings/decoders.d.ts (vendored, 8 changed lines)
@@ -20,6 +20,14 @@ export function byteLevelDecoder(): Decoder;
  */
 export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;

+/**
+ * Instantiate a new ByteFallback Decoder
+ * ByteFallback is a simple trick which converts tokens looking like `<0x61>`
+ * to pure bytes, and attempts to make them into a string. If the tokens
+ * cannot be decoded you will get � instead for each inconvertible byte token
+ */
+export function byteFallbackDecoder(): Decoder;
+
 /**
  * Instantiate a new Metaspace
  *
bindings/node/lib/bindings/decoders.js

@@ -3,6 +3,7 @@ const native = require("./native");
 module.exports = {
   byteLevelDecoder: native.decoders_ByteLevel,
   wordPieceDecoder: native.decoders_WordPiece,
+  byteFallbackDecoder: native.decoders_ByteFallback,
   metaspaceDecoder: native.decoders_Metaspace,
   bpeDecoder: native.decoders_BPEDecoder,
   ctcDecoder: native.decoders_CTC,
bindings/node/lib/bindings/decoders.test.ts

@@ -1,5 +1,6 @@
 import {
   bpeDecoder,
+  byteFallbackDecoder,
   ctcDecoder,
   metaspaceDecoder,
   sequenceDecoder,
@@ -22,6 +23,27 @@ describe("wordPieceDecoder", () => {
   });
 });

+describe("byteFallbackDecoder", () => {
+  it("accepts `undefined` as first parameter", () => {
+    expect(byteFallbackDecoder()).toBeDefined();
+  });
+
+  it("can decode arrays of strings", () => {
+    expect(byteFallbackDecoder().decode(["Hel", "lo"])).toEqual("Hello");
+    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
+    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
+    expect(byteFallbackDecoder().decode(["My", " na", "me"])).toEqual("My name");
+    expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
+    expect(byteFallbackDecoder().decode(["<0xE5>"])).toEqual("�");
+    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>"])).toEqual("��");
+    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>"])).toEqual("叫");
+    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "a"])).toEqual("��a");
+    expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>", "a"])).toEqual(
+      "叫a"
+    );
+  });
+});
+
 describe("metaspaceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
     expect(metaspaceDecoder(undefined)).toBeDefined();
bindings/node/native/src/decoders.rs

@@ -72,6 +72,16 @@ fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }

+/// byte_fallback()
+fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::decoders::byte_fallback::ByteFallback::new().into(),
+    ));
+    Ok(decoder)
+}
+
 /// metaspace(replacement: String = "_", add_prefix_space: bool = true)
 fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');

@@ -147,6 +157,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
+    m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
     m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
     m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
bindings/node/native/src/models.rs

@@ -132,6 +132,7 @@ struct BpeOptions {
     continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     fuse_unk: Option<bool>,
+    byte_fallback: Option<bool>,
 }
 impl BpeOptions {
     fn apply_to_bpe_builder(self, mut builder: BpeBuilder) -> BpeBuilder {

@@ -153,6 +154,9 @@ impl BpeOptions {
         if let Some(fuse_unk) = self.fuse_unk {
             builder = builder.fuse_unk(fuse_unk);
         }
+        if let Some(byte_fallback) = self.byte_fallback {
+            builder = builder.byte_fallback(byte_fallback);
+        }

         builder
     }
bindings/python/py_src/tokenizers/decoders/__init__.py

@@ -4,6 +4,7 @@ from .. import decoders
 Decoder = decoders.Decoder
 ByteLevel = decoders.ByteLevel
 WordPiece = decoders.WordPiece
+ByteFallback = decoders.ByteFallback
 Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
bindings/python/py_src/tokenizers/decoders/__init__.pyi

@@ -45,6 +45,30 @@ class BPEDecoder(Decoder):
     """
     pass

+class ByteFallback(Decoder):
+    """
+    ByteFallback Decoder
+    ByteFallback is a simple trick which converts tokens looking like `<0x61>`
+    to pure bytes, and attempts to make them into a string. If the tokens
+    cannot be decoded you will get � instead for each inconvertible byte token
+
+    """
+
+    def __init__(self):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class ByteLevel(Decoder):
     """
     ByteLevel Decoder
bindings/python/py_src/tokenizers/models/__init__.pyi

@@ -106,6 +106,9 @@ class BPE(Model):

         fuse_unk (:obj:`bool`, `optional`):
             Whether to fuse any subsequent unknown tokens into a single one
+
+        byte_fallback (:obj:`bool`, `optional`):
+            Whether to use the spm byte-fallback trick (defaults to False)
     """

     def __init__(

@@ -118,6 +121,7 @@ class BPE(Model):
         continuing_subword_prefix=None,
         end_of_word_suffix=None,
         fuse_unk=None,
+        byte_fallback=False,
     ):
         pass
     @staticmethod
bindings/python/src/decoders.rs

@@ -7,6 +7,7 @@ use pyo3::types::*;
 use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::decoders::bpe::BPEDecoder;
+use tk::decoders::byte_fallback::ByteFallback;
 use tk::decoders::byte_level::ByteLevel;
 use tk::decoders::ctc::CTC;
 use tk::decoders::metaspace::Metaspace;

@@ -41,6 +42,9 @@ impl PyDecoder {
             PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
                 DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
                 DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
+                DecoderWrapper::ByteFallback(_) => {
+                    Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
+                }
                 DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
                 DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
                 DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),

@@ -196,6 +200,23 @@ impl PyWordPieceDec {
     }
 }

+/// ByteFallback Decoder
+/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
+/// to pure bytes, and attempts to make them into a string. If the tokens
+/// cannot be decoded you will get � instead for each inconvertible byte token
+///
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "ByteFallback")]
+#[pyo3(text_signature = "(self)")]
+pub struct PyByteFallbackDec {}
+#[pymethods]
+impl PyByteFallbackDec {
+    #[new]
+    #[pyo3(signature = ())]
+    fn new() -> (Self, PyDecoder) {
+        (PyByteFallbackDec {}, ByteFallback::new().into())
+    }
+}
+
 /// Metaspace Decoder
 ///
 /// Args:

@@ -453,6 +474,7 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyDecoder>()?;
     m.add_class::<PyByteLevelDec>()?;
     m.add_class::<PyWordPieceDec>()?;
+    m.add_class::<PyByteFallbackDec>()?;
     m.add_class::<PyMetaspaceDec>()?;
     m.add_class::<PyBPEDecoder>()?;
     m.add_class::<PyCTCDecoder>()?;
bindings/python/src/models.rs

@@ -249,9 +249,12 @@ impl PyModel {
 ///
 ///     fuse_unk (:obj:`bool`, `optional`):
 ///         Whether to fuse any subsequent unknown tokens into a single one
+///
+///     byte_fallback (:obj:`bool`, `optional`):
+///         Whether to use the spm byte-fallback trick (defaults to False)
 #[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
 #[pyo3(
-    text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"
+    text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)"
 )]
 pub struct PyBPE {}

@@ -277,6 +280,7 @@ impl PyBPE {
             }
             "end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
             "fuse_unk" => builder = builder.fuse_unk(value.extract()?),
+            "byte_fallback" => builder = builder.byte_fallback(value.extract()?),
             _ => println!("Ignored unknown kwarg option {}", key),
         };
     }

@@ -385,6 +389,16 @@ impl PyBPE {
         setter!(self_, BPE, fuse_unk, fuse_unk);
     }

+    #[getter]
+    fn get_byte_fallback(self_: PyRef<Self>) -> bool {
+        getter!(self_, BPE, byte_fallback)
+    }
+
+    #[setter]
+    fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
+        setter!(self_, BPE, byte_fallback, byte_fallback);
+    }
+
     #[new]
     #[pyo3(signature = (vocab=None, merges=None, **kwargs))]
     fn new(
bindings/python/tests/bindings/test_decoders.py

@@ -3,7 +3,7 @@ import pickle

 import pytest

-from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece
+from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback


 class TestByteLevel:

@@ -54,6 +54,24 @@ class TestWordPiece:
         assert decoder.cleanup == True


+class TestByteFallback:
+    def test_instantiate(self):
+        assert ByteFallback() is not None
+        assert isinstance(ByteFallback(), Decoder)
+        assert isinstance(ByteFallback(), ByteFallback)
+        assert isinstance(pickle.loads(pickle.dumps(ByteFallback())), ByteFallback)
+
+    def test_decoding(self):
+        decoder = ByteFallback()
+        assert decoder.decode(["My", " na", "me"]) == "My name"
+        assert decoder.decode(["<0x61>"]) == "a"
+        assert decoder.decode(["<0xE5>"]) == "�"
+        assert decoder.decode(["<0xE5>", "<0x8f>"]) == "��"
+        assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>"]) == "叫"
+        assert decoder.decode(["<0xE5>", "<0x8f>", "a"]) == "��a"
+        assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
+
+
 class TestMetaspace:
     def test_instantiate(self):
         assert Metaspace() is not None
bindings/python/tests/bindings/test_models.py

@@ -54,6 +54,7 @@ class TestBPE:
         assert model.continuing_subword_prefix == "__prefix__"
         assert model.end_of_word_suffix == "__suffix__"
         assert model.fuse_unk == False
+        assert model.byte_fallback == False

         # Modify these
         model.dropout = 0.1

@@ -66,6 +67,8 @@ class TestBPE:
         assert model.end_of_word_suffix == "suff"
         model.fuse_unk = True
         assert model.fuse_unk == True
+        model.byte_fallback = True
+        assert model.byte_fallback == True


 class TestWordPiece:
tokenizers/src/decoders/byte_fallback.rs (new file, 108 lines)
@@ -0,0 +1,108 @@

use crate::tokenizer::{Decoder, Result};

use serde::{Deserialize, Serialize};

#[derive(Deserialize, Clone, Debug, Serialize, Default)]
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded you will get � instead for each inconvertible byte token
#[serde(tag = "type")]
#[non_exhaustive]
pub struct ByteFallback {}

impl ByteFallback {
    pub fn new() -> Self {
        Self {}
    }
}

impl Decoder for ByteFallback {
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        let mut new_tokens: Vec<String> = vec![];
        let mut previous_byte_tokens: Vec<u8> = vec![];

        for token in tokens {
            let bytes = if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
                if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) {
                    Some(byte)
                } else {
                    None
                }
            } else {
                None
            };
            if let Some(bytes) = bytes {
                previous_byte_tokens.push(bytes);
            } else {
                if !previous_byte_tokens.is_empty() {
                    if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
                        new_tokens.push(string);
                    } else {
                        for _ in 0..previous_byte_tokens.len() {
                            new_tokens.push("�".into());
                        }
                    }
                    previous_byte_tokens.clear();
                }
                new_tokens.push(token);
            }
        }
        if !previous_byte_tokens.is_empty() {
            if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
                new_tokens.push(string);
            } else {
                for _ in 0..previous_byte_tokens.len() {
                    new_tokens.push("�".into());
                }
            }
        }

        Ok(new_tokens)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode() {
        let decoder = ByteFallback::new();
        let res = decoder
            .decode_chain(vec!["Hey".into(), "friend!".into()])
            .unwrap();
        assert_eq!(res, vec!["Hey", "friend!"]);

        let res = decoder.decode_chain(vec!["<0x61>".into()]).unwrap();
        assert_eq!(res, vec!["a"]);

        let res = decoder.decode_chain(vec!["<0xE5>".into()]).unwrap();
        assert_eq!(res, vec!["�"]);

        let res = decoder
            .decode_chain(vec!["<0xE5>".into(), "<0x8f>".into()])
            .unwrap();
        assert_eq!(res, vec!["�", "�"]);

        // 叫
        let res = decoder
            .decode_chain(vec!["<0xE5>".into(), "<0x8f>".into(), "<0xab>".into()])
            .unwrap();
        assert_eq!(res, vec!["叫"]);

        let res = decoder
            .decode_chain(vec![
                "<0xE5>".into(),
                "<0x8f>".into(),
                "<0xab>".into(),
                "a".into(),
            ])
            .unwrap();
        assert_eq!(res, vec!["叫", "a"]);

        let res = decoder
            .decode_chain(vec!["<0xE5>".into(), "<0x8f>".into(), "a".into()])
            .unwrap();
        assert_eq!(res, vec!["�", "�", "a"]);
    }
}
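An editor's note on the logic above: byte tokens are recognized purely by shape (exactly six characters, `<0x` prefix, `>` suffix), buffered into a byte run, and flushed as UTF-8 whenever a non-byte token or the end of input is reached. A small usage sketch (not from the commit), assuming the `Decoder` trait re-export at the crate root:

    use tokenizers::decoders::byte_fallback::ByteFallback;
    use tokenizers::Decoder;

    fn main() -> tokenizers::Result<()> {
        let dec = ByteFallback::new();

        // 0xE5 0x8F 0xAB is the UTF-8 encoding of 叫; the non-byte token "!"
        // flushes the buffered run as a single decoded string.
        let out = dec.decode_chain(vec![
            "<0xE5>".into(), "<0x8f>".into(), "<0xab>".into(), "!".into(),
        ])?;
        assert_eq!(out, vec!["叫".to_string(), "!".to_string()]);

        // An incomplete run is not valid UTF-8: one U+FFFD per leftover byte.
        let out = dec.decode_chain(vec!["<0xE5>".into(), "<0x8f>".into()])?;
        assert_eq!(out, vec!["�".to_string(), "�".to_string()]);
        Ok(())
    }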
tokenizers/src/decoders/mod.rs

@@ -1,4 +1,5 @@
 pub mod bpe;
+pub mod byte_fallback;
 pub mod ctc;
 pub mod sequence;
 pub mod wordpiece;

@@ -10,6 +11,7 @@ pub use super::pre_tokenizers::metaspace;
 use serde::{Deserialize, Serialize};

 use crate::decoders::bpe::BPEDecoder;
+use crate::decoders::byte_fallback::ByteFallback;
 use crate::decoders::ctc::CTC;
 use crate::decoders::sequence::Sequence;
 use crate::decoders::wordpiece::WordPiece;

@@ -26,6 +28,11 @@ pub enum DecoderWrapper {
     Metaspace(Metaspace),
     CTC(CTC),
     Sequence(Sequence),
+    // XXX: This is an untagged enum, which unfortunately means order
+    // is **CRITICAL**. We absolutely need to make sure order is correct.
+    // Since byte fallback is parameter free, it **has** to be last, and will
+    // unfortunately match pretty much everything.
+    ByteFallback(ByteFallback),
 }

 impl Decoder for DecoderWrapper {

@@ -37,13 +44,28 @@ impl Decoder for DecoderWrapper {
             Self::WordPiece(wp) => wp.decode_chain(tokens),
             Self::CTC(ctc) => ctc.decode_chain(tokens),
             Self::Sequence(seq) => seq.decode_chain(tokens),
+            Self::ByteFallback(bf) => bf.decode_chain(tokens),
         }
     }
 }

 impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
 impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
+impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
 impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
 impl_enum_from!(Sequence, DecoderWrapper, Sequence);

 #[cfg(test)]
 mod tests {
     use super::*;

     #[test]
     fn decoder_serialization() {
         let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
         let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
         let serialized = serde_json::to_string(&decoder).unwrap();
         assert_eq!(serialized, json);
     }
 }
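The `XXX` comment deserves a concrete illustration. The following toy example (an editor's sketch, unrelated to the real `DecoderWrapper`, assuming `serde` with the derive feature plus `serde_json`) shows the pitfall: untagged variants are tried in declaration order, and a variant with no required fields matches almost any input, so it must come last.

    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    #[serde(untagged)]
    enum Wrapper {
        // Variants are tried in order; the first one that deserializes wins.
        Pair { left: u32, right: u32 },
        // No required fields: this matches nearly any JSON map, so if it
        // came first it would shadow every variant declared after it.
        CatchAll {},
    }

    fn main() {
        let w: Wrapper = serde_json::from_str(r#"{"left":1,"right":2}"#).unwrap();
        println!("{:?}", w); // Pair { left: 1, right: 2 }

        let w: Wrapper = serde_json::from_str(r#"{"anything":true}"#).unwrap();
        println!("{:?}", w); // CatchAll
    }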
tokenizers/src/models/bpe/model.rs

@@ -27,6 +27,7 @@ struct Config {
    continuing_subword_prefix: Option<String>,
     end_of_word_suffix: Option<String>,
     fuse_unk: bool,
+    byte_fallback: bool,
 }

 /// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.

@@ -47,6 +48,7 @@ impl Default for BpeBuilder {
                 continuing_subword_prefix: None,
                 end_of_word_suffix: None,
                 fuse_unk: false,
+                byte_fallback: false,
             },
         }
     }
@@ -115,6 +117,13 @@ impl BpeBuilder {
         self
     }

+    /// Set the `byte_fallback` option.
+    #[must_use]
+    pub fn byte_fallback(mut self, byte_fallback: bool) -> Self {
+        self.config.byte_fallback = byte_fallback;
+        self
+    }
+
     /// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
     pub fn build(mut self) -> Result<BPE> {
         // Validate dropout.

@@ -180,6 +189,7 @@ impl BpeBuilder {
             continuing_subword_prefix: self.config.continuing_subword_prefix,
             end_of_word_suffix: self.config.end_of_word_suffix,
             fuse_unk: self.config.fuse_unk,
+            byte_fallback: self.config.byte_fallback,
         })
     }
 }
@@ -206,6 +216,9 @@ pub struct BPE {
     pub end_of_word_suffix: Option<String>,
     /// Do multiple unk tokens get fused
     pub fuse_unk: bool,
+    /// Byte fallback from SentencePiece: instead of UNK, uses `"<0x00>"`-style
+    /// tokens for each byte in the unk token
+    pub byte_fallback: bool,
 }

 impl std::fmt::Debug for BPE {

@@ -216,6 +229,7 @@ impl std::fmt::Debug for BPE {
             .field("continuing_subword_prefix", &self.continuing_subword_prefix)
             .field("end_of_word_suffix", &self.end_of_word_suffix)
             .field("fuse_unk", &self.fuse_unk)
+            .field("byte_fallback", &self.byte_fallback)
             .field("vocab", &self.vocab.len())
             .field("merges", &self.merges.len())
             .finish()

@@ -243,6 +257,7 @@ impl Clone for BPE {
             continuing_subword_prefix: self.continuing_subword_prefix.clone(),
             end_of_word_suffix: self.end_of_word_suffix.clone(),
             fuse_unk: self.fuse_unk,
+            byte_fallback: self.byte_fallback,
         }
     }
 }
@@ -373,7 +388,24 @@ impl BPE {
                     unk = None;
                 }
                 word.add(*id, byte_len);
-            } else if let Some(unk_token) = &self.unk_token {
+            } else {
+                if self.byte_fallback {
+                    let tokens: Option<Vec<_>> = s
+                        .bytes()
+                        .map(|b| -> Option<&u32> {
+                            let code = format!("<{:#04X}>", b);
+
+                            self.vocab.get(&code)
+                        })
+                        .collect();
+                    if let Some(tokens) = tokens {
+                        for t in tokens {
+                            word.add(*t, 1);
+                        }
+                        continue;
+                    }
+                }
+                if let Some(unk_token) = &self.unk_token {
                     unk = match (unk, self.fuse_unk) {
                         (Some((unk_id, unk_len)), true) => {
                             // Fuse unk

@@ -390,15 +422,15 @@ impl BPE {
                             ))
                         }
                         _ => Some((
-                            *self
-                                .vocab
-                                .get(unk_token)
-                                .ok_or_else(|| Error::UnkTokenOutOfVocabulary(unk_token.to_owned()))?,
+                            *self.vocab.get(unk_token).ok_or_else(|| {
+                                Error::UnkTokenOutOfVocabulary(unk_token.to_owned())
+                            })?,
                             byte_len,
                         )),
                     };
+                }
+            }
             }
             if let Some((unk_id, unk_len)) = unk {
                 word.add(unk_id, unk_len);
             }
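The `format!("<{:#04X}>", b)` above is the subtle part, and is what the commit message's "Fixing <16 byte fallback" refers to: `#` adds the `0x` prefix, and the zero-padded width of 4 counts that prefix, so bytes below 16 still render as two hex digits and match the six-character `<0xNN>` vocab entries. A quick standalone check (an editor's sketch):

    fn main() {
        // Width 4 includes the "0x" prefix, so small bytes are zero-padded.
        assert_eq!(format!("<{:#04X}>", 0x61u8), "<0x61>");
        assert_eq!(format!("<{:#04X}>", 0x0Au8), "<0x0A>");
        // Without the padding, bytes below 16 would produce a 5-char token
        // that never matches the vocab:
        assert_eq!(format!("<{:#X}>", 0x0Au8), "<0xA>");
    }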
@@ -793,4 +825,41 @@ mod tests {
             },
         }
     }

+    #[test]
+    fn test_bpe_byte_fallback() {
+        // 0x61 == 'a' in bytes
+        let vocab: Vocab = [("<unk>".into(), 0), ("<0x61>".into(), 1)]
+            .iter()
+            .cloned()
+            .collect();
+        let bpe = BpeBuilder::default()
+            .vocab_and_merges(vocab, vec![])
+            .unk_token("<unk>".to_string())
+            .byte_fallback(true)
+            .build()
+            .unwrap();
+        let tokens = bpe.tokenize("c").unwrap();
+        assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);
+
+        let tokens = bpe.tokenize("a").unwrap();
+        assert_eq!(tokens, vec![Token::new(1u32, "<0x61>".into(), (0, 1)),]);
+    }
+
+    #[test]
+    fn test_bpe_byte_fallback_newline() {
+        // 0x0A == '\n' in bytes
+        let vocab: Vocab = [("<unk>".into(), 0), ("<0x0A>".into(), 1)]
+            .iter()
+            .cloned()
+            .collect();
+        let bpe = BpeBuilder::default()
+            .vocab_and_merges(vocab, vec![])
+            .unk_token("<unk>".to_string())
+            .byte_fallback(true)
+            .build()
+            .unwrap();
+        let tokens = bpe.tokenize("\n").unwrap();
+        assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
+    }
 }
tokenizers/src/models/bpe/serialization.rs

@@ -20,6 +20,7 @@ impl Serialize for BPE {
         model.serialize_field("continuing_subword_prefix", &self.continuing_subword_prefix)?;
         model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
         model.serialize_field("fuse_unk", &self.fuse_unk)?;
+        model.serialize_field("byte_fallback", &self.byte_fallback)?;

         // Then the large ones
         let mut merges: Vec<(&Pair, &u32)> = self

@@ -55,6 +56,7 @@ impl<'de> Deserialize<'de> for BPE {
                 "continuing_subword_prefix",
                 "end_of_word_suffix",
                 "fuse_unk",
+                "byte_fallback",
                 "vocab",
                 "merges",
             ],

@@ -105,6 +107,11 @@ impl<'de> Visitor<'de> for BPEVisitor {
                         builder = builder.fuse_unk(suffix);
                     }
                 }
+                "byte_fallback" => {
+                    if let Some(suffix) = map.next_value()? {
+                        builder = builder.byte_fallback(suffix);
+                    }
+                }
                 "vocab" => vocab = Some(map.next_value()?),
                 "merges" => merges = Some(map.next_value()?),
                 "type" => match map.next_value()? {
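Finally, a sketch of the round trip these serialization changes enable (an editor's addition, not from the commit; assumes `serde_json` among the consumer's dependencies):

    use tokenizers::models::bpe::BPE;

    fn main() -> tokenizers::Result<()> {
        // byte_fallback now survives a save/load round trip of the model JSON.
        let bpe = BPE::builder().byte_fallback(true).build()?;
        let json = serde_json::to_string(&bpe)?;
        assert!(json.contains(r#""byte_fallback":true"#));

        let back: BPE = serde_json::from_str(&json)?;
        assert!(back.byte_fallback);
        Ok(())
    }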