Changing Decoder trait to be more composable. (#938)

* Changing `Decoder` trait to be more composable.

Fix #872

* Fixing Python side.

* Fixing test.

* Updating cleanup signature, removing turbofish.
This commit is contained in:
Nicolas Patry
2022-03-17 10:32:09 +01:00
committed by GitHub
parent 1f1f86dd32
commit cdabef14c4
11 changed files with 147 additions and 80 deletions

View File

@ -14,7 +14,7 @@ class TestByteLevel:
def test_decoding(self):
decoder = ByteLevel()
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == "My name is John"
assert decoder.decode(["My", "Ġname", "Ġis", "ĠJohn"]) == ["My name is John"]
def test_manual_reload(self):
byte_level = ByteLevel()
@ -34,11 +34,25 @@ class TestWordPiece:
def test_decoding(self):
decoder = WordPiece()
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == "My name is John"
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == "I'm John"
assert decoder.decode(["My", "na", "##me", "is", "Jo", "##hn"]) == [
"My",
" na",
"me",
" is",
" Jo",
"hn",
]
assert decoder.decode(["I", "'m", "Jo", "##hn"]) == ["I", "'m", " Jo", "hn"]
decoder = WordPiece(prefix="__", cleanup=False)
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == "My name is John"
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == "I 'm John"
assert decoder.decode(["My", "na", "__me", "is", "Jo", "__hn"]) == [
"My",
" na",
"me",
" is",
" Jo",
"hn",
]
assert decoder.decode(["I", "'m", "Jo", "__hn"]) == ["I", " 'm", " Jo", "hn"]
def test_can_modify(self):
decoder = WordPiece(prefix="$$", cleanup=False)
@ -66,9 +80,9 @@ class TestMetaspace:
def test_decoding(self):
decoder = Metaspace()
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == "My name is John"
assert decoder.decode(["▁My", "▁name", "▁is", "▁John"]) == ["My", " name", " is", " John"]
decoder = Metaspace(replacement="-", add_prefix_space=False)
assert decoder.decode(["-My", "-name", "-is", "-John"]) == " My name is John"
assert decoder.decode(["-My", "-name", "-is", "-John"]) == [" My", " name", " is", " John"]
def test_can_modify(self):
decoder = Metaspace(replacement="*", add_prefix_space=False)
@ -93,12 +107,23 @@ class TestBPEDecoder:
def test_decoding(self):
decoder = BPEDecoder()
assert (
decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"])
== "My name is John"
)
assert decoder.decode(["My</w>", "na", "me</w>", "is</w>", "Jo", "hn</w>"]) == [
"My ",
"na",
"me ",
"is ",
"Jo",
"hn",
]
decoder = BPEDecoder(suffix="_")
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == "My name is John"
assert decoder.decode(["My_", "na", "me_", "is_", "Jo", "hn_"]) == [
"My ",
"na",
"me ",
"is ",
"Jo",
"hn",
]
def test_can_modify(self):
decoder = BPEDecoder(suffix="123")
@ -120,19 +145,13 @@ class TestCTCDecoder:
def test_decoding(self):
decoder = CTC()
assert (
decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
)
== "hello"
)
assert decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
) == ["h", "e", "l", "l", "o"]
decoder = CTC(pad_token="[PAD]")
assert (
decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
)
== "hello"
)
assert decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
) == ["h", "e", "l", "l", "o"]
def test_can_modify(self):
decoder = CTC(pad_token="[PAD]")