Changing Decoder trait to be more composable. (#938)

* Changing `Decoder` trait to be more composable.

Fix #872

* Fixing Python side.

* Fixing test.

* Updating cleanup signature, removing turbofish.
This commit is contained in:
Nicolas Patry
2022-03-17 10:32:09 +01:00
committed by GitHub
parent 1f1f86dd32
commit cdabef14c4
11 changed files with 147 additions and 80 deletions

View File

@@ -51,7 +51,7 @@ impl PyDecoder {
}
impl Decoder for PyDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
self.decoder.decode(tokens)
}
}
@@ -98,7 +98,7 @@ impl PyDecoder {
/// Returns:
/// :obj:`str`: The decoded string
#[text_signature = "(self, tokens)"]
fn decode(&self, tokens: Vec<String>) -> PyResult<String> {
fn decode(&self, tokens: Vec<String>) -> PyResult<Vec<String>> {
ToPyResult(self.decoder.decode(tokens)).into()
}
}
@@ -337,12 +337,12 @@ impl CustomDecoder {
}
impl Decoder for CustomDecoder {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
Python::with_gil(|py| {
let decoded = self
.inner
.call_method(py, "decode", (tokens,), None)?
.extract::<String>(py)?;
.extract(py)?;
Ok(decoded)
})
}
@@ -396,7 +396,7 @@ where
}
impl Decoder for PyDecoderWrapper {
fn decode(&self, tokens: Vec<String>) -> tk::Result<String> {
fn decode(&self, tokens: Vec<String>) -> tk::Result<Vec<String>> {
match self {
PyDecoderWrapper::Wrapped(inner) => inner.read().unwrap().decode(tokens),
PyDecoderWrapper::Custom(inner) => inner.read().unwrap().decode(tokens),

View File

@@ -14,7 +14,7 @@ class TestByteLevel:
def test_decoding(self):
decoder = ByteLevel()
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
def test_manual_reload(self):
byte_level = ByteLevel()
@@ -34,11 +34,25 @@ class TestWordPiece:
def test_decoding(self):
decoder = WordPiece()
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
"My",
" na",
"me",
" is",
" Jo",
"hn",
]
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
decoder = WordPiece(prefix="__", cleanup=False)
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
"My",
" na",
"me",
" is",
" Jo",
"hn",
]
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
def test_can_modify(self):
decoder = WordPiece(prefix="$$", cleanup=False)
@@ -66,9 +80,9 @@ class TestMetaspace:
def test_decoding(self):
decoder = Metaspace()
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
decoder = Metaspace(replacement="-", add_prefix_space=False)
assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
def test_can_modify(self):
decoder = Metaspace(replacement="*", add_prefix_space=False)
@@ -93,12 +107,23 @@ class TestBPEDecoder:
def test_decoding(self):
decoder = BPEDecoder()
assert (
decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"])
== "My name is John"
)
assert decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]) == [
"My ",
"na",
"me ",
"is ",
"Jo",
"hn",
]
decoder = BPEDecoder(suffix="_")
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
"My ",
"na",
"me ",
"is ",
"Jo",
"hn",
]
def test_can_modify(self):
decoder = BPEDecoder(suffix="123")
@@ -120,19 +145,13 @@ class TestCTCDecoder:
def test_decoding(self):
decoder = CTC()
assert (
decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
)
== "hello"
)
assert decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
) == ["h", "e", "l", "l", "o"]
decoder = CTC(pad_token="[PAD]")
assert (
decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
)
== "hello"
)
assert decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
) == ["h", "e", "l", "l", "o"]
def test_can_modify(self):
decoder = CTC(pad_token="[PAD]")