Add CTC Decoder for Wave2Vec models (#693)

* Rust - add a CTCDecoder as a seperate mod

* Adding bindings to Node + Python.

* Clippy update.

* Stub.

* Fixing roberta.json URLs.

* Moving test files to hf.co.

* Update cargo check and clippy to 1.52.

* Inner ':' actually is used for domains in sphinx.

Making `domain` work correctly was just too much work so I went the easy
way and have global roles for the custom rust extension.

* Update struct naming and docs

* Update changelog

Co-authored-by: Thomaub <github.thomaub@gmail.com>
Co-authored-by: Anthony MOI <m.anthony.moi@gmail.com>
This commit is contained in:
Nicolas Patry
2021-05-20 15:30:09 +02:00
committed by GitHub
parent e999a7b5f9
commit 2e2e7558f7
24 changed files with 353 additions and 50 deletions

View File

@@ -2,7 +2,7 @@ import pytest
import pickle
import json
from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder
from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC
class TestByteLevel:
@@ -108,3 +108,45 @@ class TestBPEDecoder:
# Modify these
decoder.suffix = "</w>"
assert decoder.suffix == "</w>"
class TestCTCDecoder:
def test_instantiate(self):
assert CTC() is not None
assert CTC(pad_token="[PAD]") is not None
assert isinstance(CTC(), Decoder)
assert isinstance(CTC(), CTC)
assert isinstance(pickle.loads(pickle.dumps(CTC())), CTC)
def test_decoding(self):
decoder = CTC()
assert (
decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
)
== "hello"
)
decoder = CTC(pad_token="[PAD]")
assert (
decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
)
== "hello"
)
def test_can_modify(self):
decoder = CTC(pad_token="[PAD]")
assert decoder.pad_token == "[PAD]"
assert decoder.word_delimiter_token == "|"
assert decoder.cleanup == True
# Modify these
decoder.pad_token = "{pad}"
assert decoder.pad_token == "{pad}"
decoder.word_delimiter_token = "_"
assert decoder.word_delimiter_token == "_"
decoder.cleanup = False
assert decoder.cleanup == False