mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-03 11:18:29 +00:00
Adding ByteFallback support for tokenizers. (#1183)
* Adding ByteFallback support for `tokenizers`. Two items added: - A flag `byte_fallback` for the `BPE` model. This will be in charge of using `<0x61>` instead of unk on unknown tokens. - A ByteFallback decoder, which will be in charge of putting everything back into string whenever possible. Showing � when the byte decoding fails (behavior checked against LlamaTokenizer in `transformers`). * Update rustdoc. * Clippy + Add BPE(byte_fallback) into bindings. * Stupid file. * Test artifacts removed. * Update stub. * Fix. * Bad file. * CRITICAL FIX: wrapper order because of untagged.... * Remove prints. * Fixing <16 byte fallback.
This commit is contained in:
@@ -4,6 +4,7 @@ from .. import decoders
|
||||
Decoder = decoders.Decoder
|
||||
ByteLevel = decoders.ByteLevel
|
||||
WordPiece = decoders.WordPiece
|
||||
ByteFallback = decoders.ByteFallback
|
||||
Metaspace = decoders.Metaspace
|
||||
BPEDecoder = decoders.BPEDecoder
|
||||
CTC = decoders.CTC
|
||||
|
||||
@@ -45,6 +45,30 @@ class BPEDecoder(Decoder):
|
||||
"""
|
||||
pass
|
||||
|
||||
class ByteFallback(Decoder):
|
||||
"""
|
||||
ByteFallback Decoder
|
||||
ByteFallback is a simple trick which converts tokens looking like `<0x61>`
|
||||
to pure bytes, and attempts to make them into a string. If the tokens
|
||||
cannot be decoded you will get � instead for each inconvertible byte token
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def decode(self, tokens):
|
||||
"""
|
||||
Decode the given list of tokens to a final string
|
||||
|
||||
Args:
|
||||
tokens (:obj:`List[str]`):
|
||||
The list of tokens to decode
|
||||
|
||||
Returns:
|
||||
:obj:`str`: The decoded string
|
||||
"""
|
||||
pass
|
||||
|
||||
class ByteLevel(Decoder):
|
||||
"""
|
||||
ByteLevel Decoder
|
||||
|
||||
@@ -106,6 +106,9 @@ class BPE(Model):
|
||||
|
||||
fuse_unk (:obj:`bool`, `optional`):
|
||||
Whether to fuse any subsequent unknown tokens into a single one
|
||||
|
||||
byte_fallback (:obj:`bool`, `optional`):
|
||||
Whether to use spm byte-fallback trick (defaults to False)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -118,6 +121,7 @@ class BPE(Model):
|
||||
continuing_subword_prefix=None,
|
||||
end_of_word_suffix=None,
|
||||
fuse_unk=None,
|
||||
byte_fallback=False,
|
||||
):
|
||||
pass
|
||||
@staticmethod
|
||||
|
||||
@@ -7,6 +7,7 @@ use pyo3::types::*;
|
||||
use serde::de::Error;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use tk::decoders::bpe::BPEDecoder;
|
||||
use tk::decoders::byte_fallback::ByteFallback;
|
||||
use tk::decoders::byte_level::ByteLevel;
|
||||
use tk::decoders::ctc::CTC;
|
||||
use tk::decoders::metaspace::Metaspace;
|
||||
@@ -41,6 +42,9 @@ impl PyDecoder {
|
||||
PyDecoderWrapper::Wrapped(inner) => match &*inner.as_ref().read().unwrap() {
|
||||
DecoderWrapper::Metaspace(_) => Py::new(py, (PyMetaspaceDec {}, base))?.into_py(py),
|
||||
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
|
||||
DecoderWrapper::ByteFallback(_) => {
|
||||
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
|
||||
}
|
||||
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
|
||||
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
|
||||
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
|
||||
@@ -196,6 +200,23 @@ impl PyWordPieceDec {
|
||||
}
|
||||
}
|
||||
|
||||
/// ByteFallback Decoder
|
||||
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
|
||||
/// to pure bytes, and attempts to make them into a string. If the tokens
|
||||
/// cannot be decoded you will get � instead for each inconvertible byte token
|
||||
///
|
||||
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "ByteFallback")]
|
||||
#[pyo3(text_signature = "(self)")]
|
||||
pub struct PyByteFallbackDec {}
|
||||
#[pymethods]
|
||||
impl PyByteFallbackDec {
|
||||
#[new]
|
||||
#[pyo3(signature = ())]
|
||||
fn new() -> (Self, PyDecoder) {
|
||||
(PyByteFallbackDec {}, ByteFallback::new().into())
|
||||
}
|
||||
}
|
||||
|
||||
/// Metaspace Decoder
|
||||
///
|
||||
/// Args:
|
||||
@@ -453,6 +474,7 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyDecoder>()?;
|
||||
m.add_class::<PyByteLevelDec>()?;
|
||||
m.add_class::<PyWordPieceDec>()?;
|
||||
m.add_class::<PyByteFallbackDec>()?;
|
||||
m.add_class::<PyMetaspaceDec>()?;
|
||||
m.add_class::<PyBPEDecoder>()?;
|
||||
m.add_class::<PyCTCDecoder>()?;
|
||||
|
||||
@@ -249,9 +249,12 @@ impl PyModel {
|
||||
///
|
||||
/// fuse_unk (:obj:`bool`, `optional`):
|
||||
/// Whether to fuse any subsequent unknown tokens into a single one
|
||||
///
|
||||
/// byte_fallback (:obj:`bool`, `optional`):
|
||||
/// Whether to use spm byte-fallback trick (defaults to False)
|
||||
#[pyclass(extends=PyModel, module = "tokenizers.models", name = "BPE")]
|
||||
#[pyo3(
|
||||
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None)"
|
||||
text_signature = "(self, vocab=None, merges=None, cache_capacity=None, dropout=None, unk_token=None, continuing_subword_prefix=None, end_of_word_suffix=None, fuse_unk=None, byte_fallback=False)"
|
||||
)]
|
||||
pub struct PyBPE {}
|
||||
|
||||
@@ -277,6 +280,7 @@ impl PyBPE {
|
||||
}
|
||||
"end_of_word_suffix" => builder = builder.end_of_word_suffix(value.extract()?),
|
||||
"fuse_unk" => builder = builder.fuse_unk(value.extract()?),
|
||||
"byte_fallback" => builder = builder.byte_fallback(value.extract()?),
|
||||
_ => println!("Ignored unknown kwarg option {}", key),
|
||||
};
|
||||
}
|
||||
@@ -385,6 +389,16 @@ impl PyBPE {
|
||||
setter!(self_, BPE, fuse_unk, fuse_unk);
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_byte_fallback(self_: PyRef<Self>) -> bool {
|
||||
getter!(self_, BPE, byte_fallback)
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_byte_fallback(self_: PyRef<Self>, byte_fallback: bool) {
|
||||
setter!(self_, BPE, byte_fallback, byte_fallback);
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[pyo3(signature = (vocab=None, merges=None, **kwargs))]
|
||||
fn new(
|
||||
|
||||
@@ -3,7 +3,7 @@ import pickle
|
||||
|
||||
import pytest
|
||||
|
||||
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece
|
||||
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
|
||||
|
||||
|
||||
class TestByteLevel:
|
||||
@@ -54,6 +54,24 @@ class TestWordPiece:
|
||||
assert decoder.cleanup == True
|
||||
|
||||
|
||||
class TestByteFallback:
|
||||
def test_instantiate(self):
|
||||
assert ByteFallback() is not None
|
||||
assert isinstance(ByteFallback(), Decoder)
|
||||
assert isinstance(ByteFallback(), ByteFallback)
|
||||
assert isinstance(pickle.loads(pickle.dumps(ByteFallback())), ByteFallback)
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = ByteFallback()
|
||||
assert decoder.decode(["My", " na", "me"]) == "My name"
|
||||
assert decoder.decode(["<0x61>"]) == "a"
|
||||
assert decoder.decode(["<0xE5>"]) == "�"
|
||||
assert decoder.decode(["<0xE5>", "<0x8f>"]) == "��"
|
||||
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>"]) == "叫"
|
||||
assert decoder.decode(["<0xE5>", "<0x8f>", "a"]) == "��a"
|
||||
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
|
||||
|
||||
|
||||
class TestMetaspace:
|
||||
def test_instantiate(self):
|
||||
assert Metaspace() is not None
|
||||
|
||||
@@ -54,6 +54,7 @@ class TestBPE:
|
||||
assert model.continuing_subword_prefix == "__prefix__"
|
||||
assert model.end_of_word_suffix == "__suffix__"
|
||||
assert model.fuse_unk == False
|
||||
assert model.byte_fallback == False
|
||||
|
||||
# Modify these
|
||||
model.dropout = 0.1
|
||||
@@ -66,6 +67,8 @@ class TestBPE:
|
||||
assert model.end_of_word_suffix == "suff"
|
||||
model.fuse_unk = True
|
||||
assert model.fuse_unk == True
|
||||
model.byte_fallback = True
|
||||
assert model.byte_fallback == True
|
||||
|
||||
|
||||
class TestWordPiece:
|
||||
|
||||
Reference in New Issue
Block a user