Adding ByteFallback support for tokenizers. (#1183)

* Adding ByteFallback support for `tokenizers`.

Two items added:

- A flag `byte_fallback` for the `BPE` model. This will be in charge
  of using byte tokens like `<0x61>` instead of unk for unknown tokens.
- A ByteFallback decoder, which will be in charge of putting such byte
  tokens back into a string whenever possible, showing � when the byte
  decoding fails (behavior checked against LlamaTokenizer in
  `transformers`); see the sketch after this list.
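
A rough standalone sketch of the trick (illustrative helper names, not the
library's API; the `<{:#04X}>` format matches what the BPE model emits):

// Encode each UTF-8 byte of an out-of-vocabulary string as a `<0xXX>`
// token, then reassemble the bytes on the decoding side.
fn to_byte_tokens(s: &str) -> Vec<String> {
    s.bytes().map(|b| format!("<{:#04X}>", b)).collect()
}

fn from_byte_tokens(tokens: &[String]) -> String {
    let bytes: Vec<u8> = tokens
        .iter()
        .filter_map(|t| u8::from_str_radix(&t[3..5], 16).ok())
        .collect();
    // The real decoder emits one � per byte when UTF-8 decoding fails;
    // `from_utf8_lossy` is a close approximation for this sketch.
    String::from_utf8_lossy(&bytes).into_owned()
}

fn main() {
    assert_eq!(to_byte_tokens("a"), vec!["<0x61>"]);
    assert_eq!(from_byte_tokens(&to_byte_tokens("叫")), "叫"); // bytes E5 8F AB
}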

* Update rustdoc.

* Clippy + Add BPE(byte_fallback) into bindings.

* Stupid file.

* Test artifacts removed.

* Update stub.

* Fix.

* Bad file.

* CRITICAL FIX: wrapper order because of untagged....

* Remove prints.

* Fixing byte fallback for bytes < 16.
Nicolas Patry
2023-03-23 16:04:32 +01:00
committed by GitHub
parent b8fbea00a9
commit 73637a0004
16 changed files with 359 additions and 21 deletions

View File

@@ -20,6 +20,14 @@ export function byteLevelDecoder(): Decoder;
*/
export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
/**
* Instantiate a new ByteFallback Decoder
* ByteFallback is a simple trick which converts tokens looking like `<0x61>`
* to pure bytes, and attempts to make them into a string. If the tokens
* cannot be decoded, you will get � instead for each inconvertible byte token.
*/
export function byteFallbackDecoder(): Decoder;
/**
* Instantiate a new Metaspace
*

View File

@@ -3,6 +3,7 @@ const native = require("./native");
module.exports = {
byteLevelDecoder: native.decoders_ByteLevel,
wordPieceDecoder: native.decoders_WordPiece,
byteFallbackDecoder: native.decoders_ByteFallback,
metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC,

View File

@@ -1,5 +1,6 @@
import {
bpeDecoder,
byteFallbackDecoder,
ctcDecoder,
metaspaceDecoder,
sequenceDecoder,
@@ -22,6 +23,27 @@ describe("wordPieceDecoder", () => {
});
});
describe("byteFallbackDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(byteFallbackDecoder()).toBeDefined();
});
it("can decode arrays of strings", () => {
expect(byteFallbackDecoder().decode(["Hel", "lo"])).toEqual("Hello");
expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
expect(byteFallbackDecoder().decode(["My", " na", "me"])).toEqual("My name");
expect(byteFallbackDecoder().decode(["<0x61>"])).toEqual("a");
expect(byteFallbackDecoder().decode(["<0xE5>"])).toEqual("<22>");
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>"])).toEqual("<22><>");
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>"])).toEqual("叫");
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "a"])).toEqual("<22><>a");
expect(byteFallbackDecoder().decode(["<0xE5>", "<0x8f>", "<0xab>", "a"])).toEqual(
"叫a"
);
});
});
describe("metaspaceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(metaspaceDecoder(undefined)).toBeDefined();

View File

@@ -72,6 +72,16 @@ fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}
/// byte_fallback()
fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
tk::decoders::byte_fallback::ByteFallback::new().into(),
));
Ok(decoder)
}
/// metaspace(replacement: String = "_", add_prefix_space: bool = true)
fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
@@ -147,6 +157,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;

View File

@@ -132,6 +132,7 @@ struct BpeOptions {
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
fuse_unk: Option<bool>,
byte_fallback: Option<bool>,
}
impl BpeOptions {
fn apply_to_bpe_builder(self, mut builder: BpeBuilder) -> BpeBuilder {
@@ -153,6 +154,9 @@ impl BpeOptions {
if let Some(fuse_unk) = self.fuse_unk {
builder = builder.fuse_unk(fuse_unk);
}
if let Some(byte_fallback) = self.byte_fallback {
builder = builder.byte_fallback(byte_fallback);
}
builder
}

View File

@@ -4,6 +4,7 @@ from .. import decoders
Decoder = decoders.Decoder
ByteLevel = decoders.ByteLevel
WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback
Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC

View File

@@ -45,6 +45,30 @@ class BPEDecoder(Decoder):
"""
pass
class ByteFallback(Decoder):
"""
ByteFallback Decoder
ByteFallback is a simple trick which converts tokens looking like `<0x61>`
to pure bytes, and attempts to make them into a string. If the tokens
cannot be decoded, you will get � instead for each inconvertible byte token.
"""
def __init__(self):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class ByteLevel(Decoder):
"""
ByteLevel Decoder

View File

@@ -106,6 +106,9 @@ class BPE(Model):
fuse_unk (:obj:`bool`, `optional`):
Whether to fuse any subsequent unknown tokens into a single one
byte_fallback (:obj:`bool`, `optional`):
Whether to use the spm (SentencePiece) byte-fallback trick (defaults to False)
"""
def __init__(
@@ -118,6 +121,7 @@ class BPE(Model):
continuing_subword_prefix=None,
end_of_word_suffix=None,
fuse_unk=None,
byte_fallback=False,
):
pass
@staticmethod

View File

@@ -7,6 +7,7 @@ use pyo3::types::*;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::metaspace::Metaspace;
@@ -41,6 +42,9 @@ impl PyDecoder {
PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
DecoderWrapper::ByteFallback(_) => {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
}
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
@@ -196,6 +200,23 @@ impl PyWordPieceDec {
}
}
/// ByteFallback Decoder
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded, you will get � instead for each inconvertible byte token.
///
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "ByteFallback")]
#[pyo3(text_signature = "(self)")]
pub struct PyByteFallbackDec {}
#[pymethods]
impl PyByteFallbackDec {
#[new]
#[pyo3(signature = ())]
fn new() -> (Self, PyDecoder) {
(PyByteFallbackDec {}, ByteFallback::new().into())
}
}
/// Metaspace Decoder
///
/// Args:
@@ -453,6 +474,7 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;

View File

@@ -249,9 +249,12 @@ impl PyModel {
///
/// fuse_unk (:obj:`bool`, `optional`):
/// Whether to fuse any subsequent unknown tokens into a single one
///
/// byte_fallback (:obj:`bool`, `optional`):
///     Whether to use the spm (SentencePiece) byte-fallback trick (defaults to False)
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
#[pyo3(
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)"
)]
pub struct PyBPE {}
@@ -277,6 +280,7 @@ impl PyBPE {
}
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
};
}
@@ -385,6 +389,16 @@ impl PyBPE {
setter!(self_, BPE, fuse_unk, fuse_unk);
}
#[getter]
fn get_byte_fallback(self_: PyRef<Self>) -> bool {
getter!(self_, BPE, byte_fallback)
}
#[setter]
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
setter!(self_, BPE, byte_fallback, byte_fallback);
}
#[new]
#[pyo3(signature = (vocab=None, merges=None, **kwargs))]
fn new(

View File

@@ -3,7 +3,7 @@ import pickle
import pytest
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
class TestByteLevel:
@@ -54,6 +54,24 @@ class TestWordPiece:
assert decoder.cleanup == True
class TestByteFallback:
def test_instantiate(self):
assert ByteFallback() is not None
assert isinstance(ByteFallback(), Decoder)
assert isinstance(ByteFallback(), ByteFallback)
assert isinstance(pickle.loads(pickle.dumps(ByteFallback())), ByteFallback)
def test_decoding(self):
decoder = ByteFallback()
assert decoder.decode(["My", " na", "me"]) == "My name"
assert decoder.decode(["<0x61>"]) == "a"
assert decoder.decode(["<0xE5>"]) == "<EFBFBD>"
assert decoder.decode(["<0xE5>", "<0x8f>"]) == "<EFBFBD><EFBFBD>"
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>"]) == ""
assert decoder.decode(["<0xE5>", "<0x8f>", "a"]) == "<EFBFBD><EFBFBD>a"
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
class TestMetaspace:
def test_instantiate(self):
assert Metaspace() is not None

View File

@@ -54,6 +54,7 @@ class TestBPE:
assert model.continuing_subword_prefix == "__prefix__"
assert model.end_of_word_suffix == "__suffix__"
assert model.fuse_unk == False
assert model.byte_fallback == False
# Modify these
model.dropout = 0.1
@@ -66,6 +67,8 @@ class TestBPE:
assert model.end_of_word_suffix == "suff"
model.fuse_unk = True
assert model.fuse_unk == True
model.byte_fallback = True
assert model.byte_fallback == True
class TestWordPiece:

View File

@@ -0,0 +1,108 @@
use crate::tokenizer::{Decoder, Result};
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Clone, Debug, Serialize, Default)]
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded, you will get � instead for each inconvertible byte token.
#[serde(tag = "type")]
#[non_exhaustive]
pub struct ByteFallback {}
impl ByteFallback {
pub fn new() -> Self {
Self {}
}
}
impl Decoder for ByteFallback {
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
let mut new_tokens: Vec<String> = vec![];
let mut previous_byte_tokens: Vec<u8> = vec![];
for token in tokens {
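// A byte token has the exact shape `<0xAB>`: 6 chars total, with the
// two hex digits at byte offsets 3..5.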
let bytes = if token.len() == 6 && token.starts_with("<0x") && token.ends_with('>') {
if let Ok(byte) = u8::from_str_radix(&token[3..5], 16) {
Some(byte)
} else {
None
}
} else {
None
};
if let Some(bytes) = bytes {
previous_byte_tokens.push(bytes);
} else {
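// A non-byte token ends the current run: flush the accumulated bytes
// as UTF-8, or push one � per byte if they do not form a valid sequence.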
if !previous_byte_tokens.is_empty() {
if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
new_tokens.push(string);
} else {
for _ in 0..previous_byte_tokens.len() {
new_tokens.push("<EFBFBD>".into());
}
}
previous_byte_tokens.clear();
}
new_tokens.push(token);
}
}
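// The stream may end in the middle of a byte run; flush it as well.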
if !previous_byte_tokens.is_empty() {
if let Ok(string) = String::from_utf8(previous_byte_tokens.clone()) {
new_tokens.push(string);
} else {
for _ in 0..previous_byte_tokens.len() {
new_tokens.push("<EFBFBD>".into());
}
}
}
Ok(new_tokens)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn decode() {
let decoder = ByteFallback::new();
let res = decoder
.decode_chain(vec!["Hey".into(), "friend!".into()])
.unwrap();
assert_eq!(res, vec!["Hey", "friend!"]);
let res = decoder.decode_chain(vec!["<0x61>".into()]).unwrap();
assert_eq!(res, vec!["a"]);
let res = decoder.decode_chain(vec!["<0xE5>".into()]).unwrap();
assert_eq!(res, vec!["<EFBFBD>"]);
let res = decoder
.decode_chain(vec!["<0xE5>".into(), "<0x8f>".into()])
.unwrap();
assert_eq!(res, vec!["<EFBFBD>", "<EFBFBD>"]);
// 叫
let res = decoder
.decode_chain(vec!["<0xE5>".into(), "<0x8f>".into(), "<0xab>".into()])
.unwrap();
assert_eq!(res, vec![""]);
let res = decoder
.decode_chain(vec![
"<0xE5>".into(),
"<0x8f>".into(),
"<0xab>".into(),
"a".into(),
])
.unwrap();
assert_eq!(res, vec!["", "a"]);
let res = decoder
.decode_chain(vec!["<0xE5>".into(), "<0x8f>".into(), "a".into()])
.unwrap();
assert_eq!(res, vec!["<EFBFBD>", "<EFBFBD>", "a"]);
}
}
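
The multi-token cases above come down to UTF-8 sequence boundaries: 0xE5
is the lead byte of a three-byte sequence, so decoding only succeeds once
all of E5 8F AB have been accumulated. A quick standalone check (std only,
not part of this file):

fn main() {
    // 0xE5 announces a 3-byte UTF-8 sequence; alone, or followed only
    // by 0x8F, the bytes are invalid, hence the � outputs above.
    assert!(String::from_utf8(vec![0xE5]).is_err());
    assert!(String::from_utf8(vec![0xE5, 0x8F]).is_err());
    assert_eq!(String::from_utf8(vec![0xE5, 0x8F, 0xAB]).unwrap(), "叫");
}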

View File

@@ -1,4 +1,5 @@
pub mod bpe;
pub mod byte_fallback;
pub mod ctc;
pub mod sequence;
pub mod wordpiece;
@@ -10,6 +11,7 @@ pub use super::pre_tokenizers::metaspace;
use serde::{Deserialize, Serialize};
use crate::decoders::bpe::BPEDecoder;
use crate::decoders::byte_fallback::ByteFallback;
use crate::decoders::ctc::CTC;
use crate::decoders::sequence::Sequence;
use crate::decoders::wordpiece::WordPiece;
@@ -26,6 +28,11 @@ pub enum DecoderWrapper {
Metaspace(Metaspace),
CTC(CTC),
Sequence(Sequence),
// XXX: This is an untagged enum, which unfortunately means order
// is **CRITICAL**. We absolutely need to make sure order is correct.
// Since byte fallback is parameter-free, it **has** to be last: it would
// unfortunately match pretty much everything otherwise.
ByteFallback(ByteFallback),
}
impl Decoder for DecoderWrapper {
@@ -37,13 +44,28 @@ impl Decoder for DecoderWrapper {
Self::WordPiece(wp) => wp.decode_chain(tokens),
Self::CTC(ctc) => ctc.decode_chain(tokens),
Self::Sequence(seq) => seq.decode_chain(tokens),
Self::ByteFallback(bf) => bf.decode_chain(tokens),
}
}
}
impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
impl_enum_from!(CTC, DecoderWrapper, CTC);
impl_enum_from!(Sequence, DecoderWrapper, Sequence);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn decoder_serialization() {
let json = r#"{"type":"Sequence","decoders":[{"type":"ByteFallback"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
}
}
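
To make the ordering constraint concrete, here is a minimal standalone
sketch of the `#[serde(untagged)]` pitfall (illustrative types, assuming
serde and serde_json are available; these are not the library's own
definitions): serde tries untagged variants top to bottom and keeps the
first that deserializes, and a field-free struct accepts virtually any
JSON map.

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct WithParam { replacement: char }

#[derive(Deserialize, Debug)]
struct NoParam {} // no fields: deserializes from almost any map

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Wrapper {
    // Tried first: only matches maps that carry `replacement`.
    WithParam(WithParam),
    // Must stay last; placed first, it would swallow the input above.
    NoParam(NoParam),
}

fn main() {
    let w: Wrapper = serde_json::from_str(r#"{"replacement":"x"}"#).unwrap();
    println!("{:?}", w); // WithParam(WithParam { replacement: 'x' })
}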

View File

@@ -27,6 +27,7 @@ struct Config {
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
fuse_unk: bool,
byte_fallback: bool,
}
/// A `BpeBuilder` can be used to create a `BPE` model with a custom configuration.
@@ -47,6 +48,7 @@ impl Default for BpeBuilder {
continuing_subword_prefix: None,
end_of_word_suffix: None,
fuse_unk: false,
byte_fallback: false,
},
}
}
@@ -115,6 +117,13 @@ impl BpeBuilder {
self
}
/// Set the `byte_fallback` option.
#[must_use]
pub fn byte_fallback(mut self, byte_fallback: bool) -> Self {
self.config.byte_fallback = byte_fallback;
self
}
/// Returns a `BPE` model that uses the `BpeBuilder`'s configuration.
pub fn build(mut self) -> Result<BPE> {
// Validate dropout.
@@ -180,6 +189,7 @@ impl BpeBuilder {
continuing_subword_prefix: self.config.continuing_subword_prefix,
end_of_word_suffix: self.config.end_of_word_suffix,
fuse_unk: self.config.fuse_unk,
byte_fallback: self.config.byte_fallback,
})
}
}
@@ -206,6 +216,9 @@ pub struct BPE {
pub end_of_word_suffix: Option<String>,
/// Do multiple unk tokens get fused
pub fuse_unk: bool,
/// Byte fallback from SentencePiece: instead of UNK, use byte tokens
/// like `"<0x00>"`, one for each byte of the unknown token
pub byte_fallback: bool,
}
impl std::fmt::Debug for BPE {
@@ -216,6 +229,7 @@ impl std::fmt::Debug for BPE {
.field("continuing_subword_prefix", &self.continuing_subword_prefix)
.field("end_of_word_suffix", &self.end_of_word_suffix)
.field("fuse_unk", &self.fuse_unk)
.field("byte_fallback", &self.byte_fallback)
.field("vocab", &self.vocab.len())
.field("merges", &self.merges.len())
.finish()
@@ -243,6 +257,7 @@ impl Clone for BPE {
continuing_subword_prefix: self.continuing_subword_prefix.clone(),
end_of_word_suffix: self.end_of_word_suffix.clone(),
fuse_unk: self.fuse_unk,
byte_fallback: self.byte_fallback,
}
}
}
@@ -373,7 +388,24 @@ impl BPE {
unk = None;
}
word.add(*id, byte_len);
} else if let Some(unk_token) = &self.unk_token {
} else {
if self.byte_fallback {
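// Try to express the unknown substring as one `<0xXX>` token per byte.
// Collecting into Option<Vec<_>> yields None as soon as any byte token
// is missing from the vocab, in which case we fall through to the
// regular unk handling below.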
let tokens: Option<Vec<_>> = s
.bytes()
.map(|b| -> Option<&u32> {
let code = format!("<{:#04X}>", b);
self.vocab.get(&code)
})
.collect();
if let Some(tokens) = tokens {
for t in tokens {
word.add(*t, 1);
}
continue;
}
}
if let Some(unk_token) = &self.unk_token {
unk = match (unk, self.fuse_unk) {
(Some((unk_id, unk_len)), true) => {
// Fuse unk
@@ -390,15 +422,15 @@
))
}
_ => Some((
*self
.vocab
.get(unk_token)
.ok_or_else(|| Error::UnkTokenOutOfVocabulary(unk_token.to_owned()))?,
*self.vocab.get(unk_token).ok_or_else(|| {
Error::UnkTokenOutOfVocabulary(unk_token.to_owned())
})?,
byte_len,
)),
};
}
}
}
if let Some((unk_id, unk_len)) = unk {
word.add(unk_id, unk_len);
}
@@ -793,4 +825,41 @@ mod tests {
},
}
}
#[test]
fn test_bpe_byte_fallback() {
// 0x61 == 'a' in bytes
let vocab: Vocab = [("<unk>".into(), 0), ("<0x61>".into(), 1)]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, vec![])
.unk_token("<unk>".to_string())
.byte_fallback(true)
.build()
.unwrap();
let tokens = bpe.tokenize("c").unwrap();
assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);
let tokens = bpe.tokenize("a").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "<0x61>".into(), (0, 1)),]);
}
#[test]
fn test_bpe_byte_fallback_newline() {
// 0x0A == '\n' in bytes
let vocab: Vocab = [("<unk>".into(), 0), ("<0x0A>".into(), 1)]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, vec![])
.unk_token("<unk>".to_string())
.byte_fallback(true)
.build()
.unwrap();
let tokens = bpe.tokenize("\n").unwrap();
assert_eq!(tokens, vec![Token::new(1u32, "<0x0A>".into(), (0, 1)),]);
}
}

View File

@@ -20,6 +20,7 @@ impl Serialize for BPE {
model.serialize_field("continuing_subword_prefix", &self.continuing_subword_prefix)?;
model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
model.serialize_field("fuse_unk", &self.fuse_unk)?;
model.serialize_field("byte_fallback", &self.byte_fallback)?;
// Then the large ones
let mut merges: Vec<(&Pair, &u32)> = self
@@ -55,6 +56,7 @@ impl<'de> Deserialize<'de> for BPE {
"continuing_subword_prefix",
"end_of_word_suffix",
"fuse_unk",
"byte_fallback",
"vocab",
"merges",
],
@@ -105,6 +107,11 @@ impl<'de> Visitor<'de> for BPEVisitor {
builder = builder.fuse_unk(suffix);
}
}
"byte_fallback" => {
if let Some(suffix) = map.next_value()? {
builder = builder.byte_fallback(suffix);
}
}
"vocab" => vocab = Some(map.next_value()?),
"merges" => merges = Some(map.next_value()?),
"type" => match map.next_value()? {