mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-04 19:58:21 +00:00
Changing Decoder trait to be more composable. (#938)
* Changing `Decoder` trait to be more composable. Fix #872 * Fixing Python side. * Fixing test. * Updating cleanup signature, removing turbofish.
This commit is contained in:
@@ -51,7 +51,7 @@ impl PyDecoder {
|
||||
}
|
||||
|
||||
impl Decoder for PyDecoder {
|
||||
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
|
||||
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
|
||||
self.decoder.decode(tokens)
|
||||
}
|
||||
}
|
||||
@@ -98,7 +98,7 @@ impl PyDecoder {
|
||||
/// Returns:
|
||||
/// :obj:`str`: The decoded string
|
||||
#[text_signature = "(self, tokens)"]
|
||||
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
|
||||
fn decode(&self, tokens: Vec<String>) -> PyResult<Vec<String>> {
|
||||
ToPyResult(self.decoder.decode(tokens)).into()
|
||||
}
|
||||
}
|
||||
@@ -337,12 +337,12 @@ impl CustomDecoder {
|
||||
}
|
||||
|
||||
impl Decoder for CustomDecoder {
|
||||
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
|
||||
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
|
||||
Python::with_gil(|py| {
|
||||
let decoded = self
|
||||
.inner
|
||||
.call_method(py, "decode", (tokens,), None)?
|
||||
.extract::<String>(py)?;
|
||||
.extract(py)?;
|
||||
Ok(decoded)
|
||||
})
|
||||
}
|
||||
@@ -396,7 +396,7 @@ where
|
||||
}
|
||||
|
||||
impl Decoder for PyDecoderWrapper {
|
||||
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
|
||||
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
|
||||
match self {
|
||||
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
|
||||
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),
|
||||
|
||||
@@ -14,7 +14,7 @@ class TestByteLevel:
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = ByteLevel()
|
||||
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
|
||||
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
|
||||
|
||||
def test_manual_reload(self):
|
||||
byte_level = ByteLevel()
|
||||
@@ -34,11 +34,25 @@ class TestWordPiece:
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = WordPiece()
|
||||
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
|
||||
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
|
||||
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
|
||||
"My",
|
||||
" na",
|
||||
"me",
|
||||
" is",
|
||||
" Jo",
|
||||
"hn",
|
||||
]
|
||||
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
|
||||
decoder = WordPiece(prefix="__", cleanup=False)
|
||||
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
|
||||
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
|
||||
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
|
||||
"My",
|
||||
" na",
|
||||
"me",
|
||||
" is",
|
||||
" Jo",
|
||||
"hn",
|
||||
]
|
||||
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
|
||||
|
||||
def test_can_modify(self):
|
||||
decoder = WordPiece(prefix="$$", cleanup=False)
|
||||
@@ -66,9 +80,9 @@ class TestMetaspace:
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = Metaspace()
|
||||
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
|
||||
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
|
||||
decoder = Metaspace(replacement="-", add_prefix_space=False)
|
||||
assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
|
||||
assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
|
||||
|
||||
def test_can_modify(self):
|
||||
decoder = Metaspace(replacement="*", add_prefix_space=False)
|
||||
@@ -93,12 +107,23 @@ class TestBPEDecoder:
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = BPEDecoder()
|
||||
assert (
|
||||
decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"])
|
||||
== "My name is John"
|
||||
)
|
||||
assert decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]) == [
|
||||
"My ",
|
||||
"na",
|
||||
"me ",
|
||||
"is ",
|
||||
"Jo",
|
||||
"hn",
|
||||
]
|
||||
decoder = BPEDecoder(suffix="_")
|
||||
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
|
||||
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
|
||||
"My ",
|
||||
"na",
|
||||
"me ",
|
||||
"is ",
|
||||
"Jo",
|
||||
"hn",
|
||||
]
|
||||
|
||||
def test_can_modify(self):
|
||||
decoder = BPEDecoder(suffix="123")
|
||||
@@ -120,19 +145,13 @@ class TestCTCDecoder:
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = CTC()
|
||||
assert (
|
||||
decoder.decode(
|
||||
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
|
||||
)
|
||||
== "hello"
|
||||
)
|
||||
assert decoder.decode(
|
||||
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
|
||||
) == ["h", "e", "l", "l", "o"]
|
||||
decoder = CTC(pad_token="[PAD]")
|
||||
assert (
|
||||
decoder.decode(
|
||||
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
|
||||
)
|
||||
== "hello"
|
||||
)
|
||||
assert decoder.decode(
|
||||
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
|
||||
) == ["h", "e", "l", "l", "o"]
|
||||
|
||||
def test_can_modify(self):
|
||||
decoder = CTC(pad_token="[PAD]")
|
||||
|
||||
Reference in New Issue
Block a user