Adding ByteFallback support for tokenizers. (#1183)

* Adding ByteFallback support for `tokenizers`.

Two items added:

- A flag `byte_fallback` for the `BPE` model. This will be in charge
  of using `<0x61>` instead of unk on unknown tokens.
- A ByteFallback decoder, which will be in charge of putting everything
  back into a string whenever possible, showing � when the byte decoding
  fails (behavior checked against LlamaTokenizer in `transformers`).

* Update rustdoc.

* Clippy + Add BPE(byte_fallback) into bindings.

* Stupid file.

* Test artifacts removed.

* Update stub.

* Fix.

* Bad file.

* CRITICAL FIX: wrapper order because of untagged....

* Remove prints.

* Fixing <16 byte fallback.
This commit is contained in:
Nicolas Patry
2023-03-23 16:04:32 +01:00
committed by GitHub
parent b8fbea00a9
commit 73637a0004
16 changed files with 359 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ from .. import decoders
Decoder = decoders.Decoder
ByteLevel = decoders.ByteLevel
WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback
Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC

View File

@@ -45,6 +45,30 @@ class BPEDecoder(Decoder):
"""
pass
class ByteFallback(Decoder):
    """
    ByteFallback Decoder
    ByteFallback is a simple trick which converts tokens looking like `<0x61>`
    to pure bytes, and attempts to make them into a string. If the tokens
    cannot be decoded you will get � instead for each inconvertable byte token

    """

    def __init__(self):
        pass

    def decode(self, tokens):
        """
        Decode the given list of tokens to a final string

        Args:
            tokens (:obj:`List[str]`):
                The list of tokens to decode

        Returns:
            :obj:`str`: The decoded string
        """
        pass
class ByteLevel(Decoder):
"""
ByteLevel Decoder

View File

@@ -106,6 +106,9 @@ class BPE(Model):
fuse_unk (:obj:`bool`, `optional`):
Whether to fuse any subsequent unknown tokens into a single one
byte_fallback (:obj:`bool`, `optional`):
Whether to use spm byte-fallback trick (defaults to False)
"""
def __init__(
@@ -118,6 +121,7 @@ class BPE(Model):
continuing_subword_prefix=None,
end_of_word_suffix=None,
fuse_unk=None,
byte_fallback=False,
):
pass
@staticmethod

View File

@@ -7,6 +7,7 @@ use pyo3::types::*;
use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::metaspace::Metaspace;
@@ -41,6 +42,9 @@ impl PyDecoder {
PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
DecoderWrapper::ByteFallback(_) => {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
}
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
@@ -196,6 +200,23 @@ impl PyWordPieceDec {
}
}
/// ByteFallback Decoder
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded you will get � instead for each inconvertable byte token
///
// NOTE: the `///` text above is not just rustdoc; pyo3 publishes it as the
// Python class docstring, so the replacement character must be the real
// U+FFFD and not an escaped/garbled form.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "ByteFallback")]
#[pyo3(text_signature = "(self)")]
pub struct PyByteFallbackDec {}

#[pymethods]
impl PyByteFallbackDec {
    // Python-side constructor. The decoder is stateless, hence the empty
    // signature. Because the class `extends` PyDecoder, the constructor returns
    // the (subclass, base) pair; the base wraps a fresh `ByteFallback` decoder.
    #[new]
    #[pyo3(signature = ())]
    fn new() -> (Self, PyDecoder) {
        (PyByteFallbackDec {}, ByteFallback::new().into())
    }
}
/// Metaspace Decoder
///
/// Args:
@@ -453,6 +474,7 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;

View File

@@ -249,9 +249,12 @@ impl PyModel {
///
/// fuse_unk (:obj:`bool`, `optional`):
/// Whether to fuse any subsequent unknown tokens into a single one
///
/// byte_fallback (:obj:`bool`, `optional`):
/// Whether to use spm byte-fallback trick (defaults to False)
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
#[pyo3(
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)"
)]
pub struct PyBPE {}
@@ -277,6 +280,7 @@ impl PyBPE {
}
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
_ => println!("Ignored unknown kwarg option {}", key),
};
}
@@ -385,6 +389,16 @@ impl PyBPE {
setter!(self_, BPE, fuse_unk, fuse_unk);
}
// Read accessor registered with pyo3 as a property getter; returns the
// `byte_fallback` flag as a bool.
// The lookup itself is delegated to the `getter!` macro, which is defined
// outside this view — presumably it reads the `byte_fallback` field of the
// wrapped `BPE` model; confirm against the macro definition.
#[getter]
fn get_byte_fallback(self_: PyRef<Self>) -> bool {
getter!(self_, BPE, byte_fallback)
}
// Write accessor registered with pyo3 as a property setter; takes the new
// `byte_fallback` value as a bool.
// The update is delegated to the `setter!` macro (defined outside this view).
// NOTE(review): `self_` is a shared `PyRef`, so the mutation presumably goes
// through interior mutability on the wrapped model — confirm in the macro.
#[setter]
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
setter!(self_, BPE, byte_fallback, byte_fallback);
}
#[new]
#[pyo3(signature = (vocab=None, merges=None, **kwargs))]
fn new(

View File

@@ -3,7 +3,7 @@ import pickle
import pytest
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
class TestByteLevel:
@@ -54,6 +54,24 @@ class TestWordPiece:
assert decoder.cleanup == True
class TestByteFallback:
def test_instantiate(self):
assert ByteFallback() is not None
assert isinstance(ByteFallback(), Decoder)
assert isinstance(ByteFallback(), ByteFallback)
assert isinstance(pickle.loads(pickle.dumps(ByteFallback())), ByteFallback)
def test_decoding(self):
decoder = ByteFallback()
assert decoder.decode(["My", " na", "me"]) == "My name"
assert decoder.decode(["<0x61>"]) == "a"
assert decoder.decode(["<0xE5>"]) == "<EFBFBD>"
assert decoder.decode(["<0xE5>", "<0x8f>"]) == "<EFBFBD><EFBFBD>"
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>"]) == ""
assert decoder.decode(["<0xE5>", "<0x8f>", "a"]) == "<EFBFBD><EFBFBD>a"
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
class TestMetaspace:
def test_instantiate(self):
assert Metaspace() is not None

View File

@@ -54,6 +54,7 @@ class TestBPE:
assert model.continuing_subword_prefix == "__prefix__"
assert model.end_of_word_suffix == "__suffix__"
assert model.fuse_unk == False
assert model.byte_fallback == False
# Modify these
model.dropout = 0.1
@@ -66,6 +67,8 @@ class TestBPE:
assert model.end_of_word_suffix == "suff"
model.fuse_unk = True
assert model.fuse_unk == True
model.byte_fallback = True
assert model.byte_fallback == True
class TestWordPiece: