Adding 2 new decoders: (#1196)

* Adding 2 new decoders:

- Fuse will simply concatenate all tokens into 1 string
- Strip will remove n char from left or right

Sequence(Replace("_", " "), Fuse(), Strip(1, 0)) should be what we want
for the `Metaspace` thing.

- Note: Added a new dependency from better parsing of decoders.
This is due to untagged enums which can match anything the `MustBe`
ensure there's no issue between Fuse and ByteFallback.
Since both are new the chances for backward incompatibility is low.

* Fixing picking/unpickling (using default args.).

* Stub.

* Black.

* Fixing node.
This commit is contained in:
Nicolas Patry
2023-03-24 00:50:54 +01:00
committed by GitHub
parent d2c8190a0f
commit e4aea890d5
13 changed files with 311 additions and 7 deletions

View File

@@ -6,6 +6,8 @@ ByteLevel = decoders.ByteLevel
Replace = decoders.Replace
WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback
Fuse = decoders.Fuse
Strip = decoders.Strip
Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC

View File

@@ -121,6 +121,29 @@ class CTC(Decoder):
"""
pass
class Fuse(Decoder):
"""
Fuse Decoder
Fuse simply fuses every token into a single string.
This is the last step of decoding, this decoder exists only if
there is need to add other decoders *after* the fusion
"""
def __init__(self):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class Metaspace(Decoder):
"""
Metaspace Decoder
@@ -197,6 +220,27 @@ class Sequence(Decoder):
"""
pass
class Strip(Decoder):
"""
Strip normalizer
Strips n left characters of each token, or n right characters of each token
"""
def __init__(self, left=0, right=0):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class WordPiece(Decoder):
"""
WordPiece Decoder

View File

@@ -11,8 +11,10 @@ use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::fuse::Fuse;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::sequence::Sequence;
use tk::decoders::strip::Strip;
use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper;
use tk::normalizers::replace::Replace;
@@ -47,6 +49,8 @@ impl PyDecoder {
DecoderWrapper::ByteFallback(_) => {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
}
DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?.into_py(py),
DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?.into_py(py),
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
@@ -238,6 +242,56 @@ impl PyByteFallbackDec {
}
}
/// Fuse Decoder
/// Fuse simply fuses every token into a single string.
/// This is the last step of decoding, this decoder exists only if
/// there is need to add other decoders *after* the fusion
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Fuse")]
#[pyo3(text_signature = "(self)")]
pub struct PyFuseDec {}
#[pymethods]
impl PyFuseDec {
#[new]
#[pyo3(signature = ())]
fn new() -> (Self, PyDecoder) {
(PyFuseDec {}, Fuse::new().into())
}
}
/// Strip normalizer
/// Strips n left characters of each token, or n right characters of each token
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
#[pyo3(text_signature = "(self, left=0, right=0)")]
pub struct PyStrip {}
#[pymethods]
impl PyStrip {
#[getter]
fn get_left(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, left)
}
#[setter]
fn set_left(self_: PyRef<Self>, left: usize) {
setter!(self_, Strip, left, left)
}
#[getter]
fn get_right(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, right)
}
#[setter]
fn set_right(self_: PyRef<Self>, right: usize) {
setter!(self_, Strip, right, right)
}
#[new]
#[pyo3(signature = (left=0, right=0))]
fn new(left: usize, right: usize) -> (Self, PyDecoder) {
(PyStrip {}, Strip::new(left, right).into())
}
}
/// Metaspace Decoder
///
/// Args:
@@ -497,6 +551,8 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyReplaceDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyFuseDec>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?;

View File

@@ -13,6 +13,8 @@ from tokenizers.decoders import (
WordPiece,
ByteFallback,
Replace,
Strip,
Fuse,
)
@@ -94,6 +96,30 @@ class TestByteFallback:
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
class TestFuse:
def test_instantiate(self):
assert Fuse() is not None
assert isinstance(Fuse(), Decoder)
assert isinstance(Fuse(), Fuse)
assert isinstance(pickle.loads(pickle.dumps(Fuse())), Fuse)
def test_decoding(self):
decoder = Fuse()
assert decoder.decode(["My", " na", "me"]) == "My name"
class TestStrip:
def test_instantiate(self):
assert Strip(left=0, right=0) is not None
assert isinstance(Strip(left=0, right=0), Decoder)
assert isinstance(Strip(left=0, right=0), Strip)
assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
def test_decoding(self):
decoder = Strip(left=1, right=0)
assert decoder.decode(["My", " na", "me"]) == "ynae"
class TestMetaspace:
def test_instantiate(self):
assert Metaspace() is not None