Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Add CTC Decoder for Wave2Vec models (#693)
* Rust - add a CTCDecoder as a separate mod
* Adding bindings to Node + Python.
* Clippy update.
* Stub.
* Fixing roberta.json URLs.
* Moving test files to hf.co.
* Update cargo check and clippy to 1.52.
* Inner ':' is actually used for domains in Sphinx. Making `domain` work correctly was just too much work, so I went the easy way and added global roles for the custom Rust extension.
* Update struct naming and docs
* Update changelog

Co-authored-by: Thomaub <github.thomaub@gmail.com>
Co-authored-by: Anthony MOI <m.anthony.moi@gmail.com>
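For orientation, here is a minimal sketch of how the new decoder is used from the Python bindings once this commit is in place (it mirrors the Python test added further down; the result follows the CTC decoding rule implemented in tokenizers/src/decoders/ctc.rs):

    from tokenizers.decoders import CTC

    decoder = CTC()  # defaults: pad_token="<pad>", word_delimiter_token="|", cleanup=True
    decoder.decode(["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"])
    # -> "hello"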
@@ -27,7 +27,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
 
 $(DATA_DIR)/roberta.json :
     $(dir_guard)
-    wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
+    wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@
 
 $(DATA_DIR)/tokenizer-wiki.json :
     $(dir_guard)
13  bindings/node/lib/bindings/decoders.d.ts  (vendored)
@@ -36,3 +36,16 @@ export function metaspaceDecoder(replacement?: string, addPrefixSpace?: boolean)
  * This suffix will be replaced by whitespaces during the decoding
  */
 export function bpeDecoder(suffix?: string): Decoder;
+
+/**
+ * Instantiate a new CTC Decoder
+ * @param [pad_token='pad'] The pad token used by CTC to delimit a new token.
+ * @param [word_delimiter_token='|'] The word delimiter token. It will be replaced by a space
+ * @param [cleanup=true] Whether to cleanup some tokenization artifacts.
+ * Mainly spaces before punctuation, and some abbreviated english forms.
+ */
+export function ctcDecoder(
+  pad_token?: string,
+  word_delimiter_token?: string,
+  cleanup?: boolean
+): Decoder;
@@ -5,4 +5,5 @@ module.exports = {
   wordPieceDecoder: native.decoders_WordPiece,
   metaspaceDecoder: native.decoders_Metaspace,
   bpeDecoder: native.decoders_BPEDecoder,
+  ctcDecoder: native.decoders_CTC,
 };
@@ -1,4 +1,4 @@
-import { bpeDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
+import { bpeDecoder, ctcDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
 
 describe("wordPieceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
@@ -31,3 +31,14 @@ describe("bpeDecoder", () => {
     expect(bpeDecoder(undefined)).toBeDefined();
   });
 });
+
+describe("ctcDecoder", () => {
+  it("accepts `undefined` as parameter", () => {
+    expect(ctcDecoder(undefined)).toBeDefined();
+  });
+  it("encodes correctly", () => {
+    expect(
+      ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
+    ).toEqual("hello");
+  });
+});
@@ -97,11 +97,30 @@ fn bpe_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// ctc_decoder(pad_token: String = "<pad>", word_delimiter_token: String = "|", cleanup = true)
+fn ctc_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let pad_token = cx
+        .extract_opt::<String>(0)?
+        .unwrap_or_else(|| String::from("<pad>"));
+    let word_delimiter_token = cx
+        .extract_opt::<String>(1)?
+        .unwrap_or_else(|| String::from("|"));
+    let cleanup = cx.extract_opt::<bool>(2)?.unwrap_or(true);
+
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::decoders::ctc::CTC::new(pad_token, word_delimiter_token, cleanup).into(),
+    ));
+    Ok(decoder)
+}
+
 /// Register everything here
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
     m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
+    m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
     Ok(())
 }
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [Unreleased]
+
+### Added
+- [#693]: Add a CTC Decoder for Wave2Vec models
+
 ## [0.10.2]
 
 ### Fixed
@@ -309,6 +314,7 @@ delimiter (Works like `.split(delimiter)`)
 - Fix a bug that was causing crashes in Python 3.5
 
 
+[#694]: https://github.com/huggingface/tokenizers/pull/694
 [#674]: https://github.com/huggingface/tokenizers/pull/674
 [#656]: https://github.com/huggingface/tokenizers/pull/656
 [#652]: https://github.com/huggingface/tokenizers/pull/652
@@ -30,4 +30,4 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
 
 $(DATA_DIR)/roberta.json :
     $(dir_guard)
-    wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
+    wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@
@@ -5,3 +5,4 @@ ByteLevel = decoders.ByteLevel
 WordPiece = decoders.WordPiece
 Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
+CTC = decoders.CTC
@@ -68,6 +68,35 @@ class ByteLevel(Decoder):
     """
     pass
 
+class CTC(Decoder):
+    """
+    CTC Decoder
+
+    Args:
+        pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
+            The pad token used by CTC to delimit a new token.
+        word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
+            The word delimiter token. It will be replaced by a <space>
+        cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to cleanup some tokenization artifacts.
+            Mainly spaces before punctuation, and some abbreviated english forms.
+    """
+
+    def __init__(self, pad_token="<pad>", word_delimiter_token="|", cleanup=True):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Metaspace(Decoder):
     """
     Metaspace Decoder
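As a quick illustration of the stub documented above, a hedged usage sketch (construction arguments and the attribute setters mirror the Python binding tests added later in this diff):

    from tokenizers.decoders import CTC

    decoder = CTC(pad_token="[PAD]")
    decoder.word_delimiter_token = "_"  # attributes are exposed through getters/setters
    decoder.cleanup = False
    decoder.decode(["[PAD]", "h", "h", "i", "[PAD]"])  # -> "hi"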
@@ -8,6 +8,7 @@ use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_level::ByteLevel;
+use tk::decoders::ctc::CTC;
 use tk::decoders::metaspace::Metaspace;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
@@ -43,6 +44,7 @@ impl PyDecoder {
             DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
             DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
             DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
+            DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
         },
     })
 }
@@ -264,6 +266,65 @@ impl PyBPEDecoder {
     }
 }
 
+/// CTC Decoder
+///
+/// Args:
+///     pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
+///         The pad token used by CTC to delimit a new token.
+///     word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
+///         The word delimiter token. It will be replaced by a <space>
+///     cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
+///         Whether to cleanup some tokenization artifacts.
+///         Mainly spaces before punctuation, and some abbreviated english forms.
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=CTC)]
+#[text_signature = "(self, pad_token=\"<pad>\", word_delimiter_token=\"|\", cleanup=True)"]
+pub struct PyCTCDecoder {}
+#[pymethods]
+impl PyCTCDecoder {
+    #[getter]
+    fn get_pad_token(self_: PyRef<Self>) -> String {
+        getter!(self_, CTC, pad_token.clone())
+    }
+
+    #[setter]
+    fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
+        setter!(self_, CTC, pad_token, pad_token);
+    }
+
+    #[getter]
+    fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
+        getter!(self_, CTC, word_delimiter_token.clone())
+    }
+
+    #[setter]
+    fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
+        setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
+    }
+
+    #[getter]
+    fn get_cleanup(self_: PyRef<Self>) -> bool {
+        getter!(self_, CTC, cleanup)
+    }
+
+    #[setter]
+    fn set_cleanup(self_: PyRef<Self>, cleanup: bool) {
+        setter!(self_, CTC, cleanup, cleanup);
+    }
+
+    #[new]
+    #[args(
+        pad_token = "String::from(\"<pad>\")",
+        word_delimiter_token = "String::from(\"|\")",
+        cleanup = "true"
+    )]
+    fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> (Self, PyDecoder) {
+        (
+            PyCTCDecoder {},
+            CTC::new(pad_token, word_delimiter_token, cleanup).into(),
+        )
+    }
+}
+
 #[derive(Clone)]
 pub(crate) struct CustomDecoder {
     inner: PyObject,
@@ -87,6 +87,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<decoders::PyWordPieceDec>()?;
     m.add_class::<decoders::PyMetaspaceDec>()?;
     m.add_class::<decoders::PyBPEDecoder>()?;
+    m.add_class::<decoders::PyCTCDecoder>()?;
     Ok(())
 }
 
@@ -15,7 +15,7 @@ where
     I: Iterator,
 {
     pub fn new(iter: I, length: Option<usize>) -> Self {
-        Self { iter, length }
+        Self { length, iter }
     }
 }
 
@@ -2,7 +2,7 @@ import pytest
 import pickle
 import json
 
-from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder
+from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC
 
 
 class TestByteLevel:
@@ -108,3 +108,45 @@ class TestBPEDecoder:
         # Modify these
         decoder.suffix = "</w>"
         assert decoder.suffix == "</w>"
+
+
+class TestCTCDecoder:
+    def test_instantiate(self):
+        assert CTC() is not None
+        assert CTC(pad_token="[PAD]") is not None
+        assert isinstance(CTC(), Decoder)
+        assert isinstance(CTC(), CTC)
+        assert isinstance(pickle.loads(pickle.dumps(CTC())), CTC)
+
+    def test_decoding(self):
+        decoder = CTC()
+        assert (
+            decoder.decode(
+                ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
+            )
+            == "hello"
+        )
+        decoder = CTC(pad_token="[PAD]")
+        assert (
+            decoder.decode(
+                ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
+            )
+            == "hello"
+        )
+
+    def test_can_modify(self):
+        decoder = CTC(pad_token="[PAD]")
+
+        assert decoder.pad_token == "[PAD]"
+        assert decoder.word_delimiter_token == "|"
+        assert decoder.cleanup == True
+
+        # Modify these
+        decoder.pad_token = "{pad}"
+        assert decoder.pad_token == "{pad}"
+
+        decoder.word_delimiter_token = "_"
+        assert decoder.word_delimiter_token == "_"
+
+        decoder.cleanup = False
+        assert decoder.cleanup == False
@@ -10,7 +10,7 @@ logger = sphinx.util.logging.getLogger(__name__)
 
 class RustRef:
     def __call__(self, name, rawtext, text, lineno, inliner, options={}, content=[]):
-        doctype = name.split(":")[1]
+        doctype = name.split("_")[1]
         parts = text.split("::")
 
         if text.startswith("~"):
@@ -87,10 +87,10 @@ class RustRef:
 
 
 def setup(app):
-    app.add_role("rust:struct", RustRef())
-    app.add_role("rust:func", RustRef())
-    app.add_role("rust:meth", RustRef())
-    app.add_role("rust:trait", RustRef())
+    app.add_role("rust_struct", RustRef())
+    app.add_role("rust_func", RustRef())
+    app.add_role("rust_meth", RustRef())
+    app.add_role("rust_trait", RustRef())
 
     return {
         "version": "0.1",
@@ -58,47 +58,47 @@
     classmethod
         static method
     Tokenizer
-        :rust:struct:`~tokenizers::tokenizer::Tokenizer`
+        :rust_struct:`~tokenizers::tokenizer::Tokenizer`
     Tokenizer.train
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::train`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::train`
     Tokenizer.save
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::save`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::save`
     Tokenizer.from_file
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::from_file`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::from_file`
    Tokenizer.encode
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::encode`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::encode`
    Tokenizer.encode_batch
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::encode_batch`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::encode_batch`
    Tokenizer.decode
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::decode`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::decode`
    Tokenizer.decode_batch
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::decode_batch`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::decode_batch`
    Tokenizer.token_to_id
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::token_to_id`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::token_to_id`
    Tokenizer.enable_padding
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::enable_padding`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::enable_padding`
    Encoding
-        :rust:struct:`~tokenizers::tokenizer::Encoding`
+        :rust_struct:`~tokenizers::tokenizer::Encoding`
    TemplateProcessing
-        :rust:struct:`~tokenizers::processors::template::TemplateProcessing`
+        :rust_struct:`~tokenizers::processors::template::TemplateProcessing`
    Normalizer
-        :rust:trait:`~tokenizers::tokenizer::Normalizer`
+        :rust_trait:`~tokenizers::tokenizer::Normalizer`
    normalizers.Sequence
-        :rust:struct:`~tokenizers::normalizers::utils::Sequence`
+        :rust_struct:`~tokenizers::normalizers::utils::Sequence`
    pre_tokenizers.Whitespace
-        :rust:struct:`~tokenizers::normalizers::whitespace::Whitespace`
+        :rust_struct:`~tokenizers::normalizers::whitespace::Whitespace`
    PreTokenizer
-        :rust:trait:`~tokenizers::tokenizer::PreTokenizer`
+        :rust_trait:`~tokenizers::tokenizer::PreTokenizer`
    models.BPE
-        :rust:struct:`~tokenizers::models::bpe::BPE`
+        :rust_struct:`~tokenizers::models::bpe::BPE`
    models.Unigram
-        :rust:struct:`~tokenizers::models::unigram::Unigram`
+        :rust_struct:`~tokenizers::models::unigram::Unigram`
    models.WordLevel
-        :rust:struct:`~tokenizers::models::wordlevel::WordLevel`
+        :rust_struct:`~tokenizers::models::wordlevel::WordLevel`
    models.WordPiece
-        :rust:struct:`~tokenizers::models::wordpiece::WordPiece`
+        :rust_struct:`~tokenizers::models::wordpiece::WordPiece`
    Decoder
-        :rust:trait:`~tokenizers::tokenizer::Decoder`
+        :rust_trait:`~tokenizers::tokenizer::Decoder`
 
 .. entities:: node
@@ -44,7 +44,7 @@ Training the tokenizer
 .. entities:: rust
 
     BpeTrainer
-        :rust:struct:`~tokenizers::models::bpe::BpeTrainer`
+        :rust_struct:`~tokenizers::models::bpe::BpeTrainer`
    vocab_size
        :obj:`vocab_size`
    min_frequency
@@ -55,7 +55,7 @@ $(DATA_DIR)/bert-% :
 
 $(DATA_DIR)/unigram% :
     $(dir_guard)
-    wget https://storage.googleapis.com/tokenizers/unigram$* -O $@
+    wget https://huggingface.co/Narsil/small/raw/main/unigram$* -O $@
 
 $(DATA_DIR)/albert-base-v1-tokenizer.json :
     $(dir_guard)
@@ -70,7 +70,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
 
 $(DATA_DIR)/roberta.json :
     $(dir_guard)
-    wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
+    wget https://huggingface.co/Narsil/small/raw/main/roberta.json -O $@
 
 $(DATA_DIR)/tokenizer-wiki.json :
     $(dir_guard)
103  tokenizers/src/decoders/ctc.rs  (new file)
@@ -0,0 +1,103 @@
+use crate::decoders::wordpiece;
+use crate::tokenizer::{Decoder, Result};
+
+use itertools::Itertools;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+/// The CTC (Connectionist Temporal Classification) decoder takes care
+/// of sanitizing a list of inputs token.
+/// Due to some alignement problem the output of some models can come
+/// with duplicated token.
+#[serde(tag = "type")]
+#[non_exhaustive]
+pub struct CTC {
+    /// The pad token used by CTC to delimit a new token.
+    pub pad_token: String,
+    /// The word delimiter token. It will be replaced by a <space>
+    pub word_delimiter_token: String,
+    /// Whether to cleanup some tokenization artifacts.
+    /// Mainly spaces before punctuation, and some abbreviated english forms.
+    pub cleanup: bool,
+}
+
+impl CTC {
+    pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self {
+        Self {
+            pad_token,
+            word_delimiter_token,
+            cleanup,
+        }
+    }
+}
+
+impl Default for CTC {
+    fn default() -> Self {
+        Self {
+            pad_token: "<pad>".to_string(),
+            word_delimiter_token: "|".to_string(),
+            cleanup: true,
+        }
+    }
+}
+
+impl Decoder for CTC {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let mut output = tokens
+            .into_iter()
+            .dedup()
+            .join("")
+            .replace(&self.pad_token, "");
+        if self.cleanup {
+            output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
+        }
+        Ok(output)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn handmade_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad>"
+            .split(' ')
+            .map(|s| s.to_string())
+            .collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "hello".to_string()
+        );
+    }
+    #[test]
+    fn handmade_with_delimiter_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad> <pad> | <pad> w o o o r <pad> <pad> l l d <pad> <pad> <pad> <pad>"
+            .split(' ')
+            .map(|s| s.to_string())
+            .collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "hello world".to_string()
+        );
+    }
+    #[test]
+    fn librispeech_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
+        );
+    }
+    #[test]
+    fn another_librispeech_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
+        );
+    }
+}
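To make the decode rule above concrete, a rough Python transliteration of CTC::decode (illustrative only; the real implementation also runs the wordpiece punctuation cleanup before replacing the word delimiter):

    from itertools import groupby

    def ctc_decode(tokens, pad_token="<pad>", word_delimiter_token="|", cleanup=True):
        # collapse consecutive duplicate tokens, join, then drop the pad token
        output = "".join(key for key, _ in groupby(tokens)).replace(pad_token, "")
        if cleanup:
            output = output.replace(word_delimiter_token, " ")
        return output

    ctc_decode("<pad> <pad> h e e l l <pad> l o o o <pad>".split(" "))  # -> "hello"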
@@ -1,4 +1,5 @@
 pub mod bpe;
+pub mod ctc;
 pub mod wordpiece;
 
 // Re-export these as decoders
@@ -8,6 +9,7 @@ pub use super::pre_tokenizers::metaspace;
 use serde::{Deserialize, Serialize};
 
 use crate::decoders::bpe::BPEDecoder;
+use crate::decoders::ctc::CTC;
 use crate::decoders::wordpiece::WordPiece;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
@@ -20,6 +22,7 @@ pub enum DecoderWrapper {
     ByteLevel(ByteLevel),
     WordPiece(WordPiece),
     Metaspace(Metaspace),
+    CTC(CTC),
 }
 
 impl Decoder for DecoderWrapper {
@@ -29,6 +32,7 @@ impl Decoder for DecoderWrapper {
             DecoderWrapper::ByteLevel(bl) => bl.decode(tokens),
             DecoderWrapper::Metaspace(ms) => ms.decode(tokens),
             DecoderWrapper::WordPiece(wp) => wp.decode(tokens),
+            DecoderWrapper::CTC(ctc) => ctc.decode(tokens),
         }
     }
 }
@@ -37,3 +41,4 @@ impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
 impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
 impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
+impl_enum_from!(CTC, DecoderWrapper, CTC);
@@ -28,23 +28,26 @@ impl Default for WordPiece {
         }
     }
 }
+pub fn cleanup(dirty_input: String) -> String {
+    dirty_input
+        .replace(" .", ".")
+        .replace(" ?", "?")
+        .replace(" !", "!")
+        .replace(" ,", ",")
+        .replace(" ' ", "'")
+        .replace(" n't", "n't")
+        .replace(" 'm", "'m")
+        .replace(" do not", " don't")
+        .replace(" 's", "'s")
+        .replace(" 've", "'ve")
+        .replace(" 're", "'re")
+}
 
 impl Decoder for WordPiece {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
         let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
         if self.cleanup {
-            output = output
-                .replace(" .", ".")
-                .replace(" ?", "?")
-                .replace(" !", "!")
-                .replace(" ,", ",")
-                .replace(" ' ", "'")
-                .replace(" n't", "n't")
-                .replace(" 'm", "'m")
-                .replace(" do not", " don't")
-                .replace(" 's", "'s")
-                .replace(" 've", "'ve")
-                .replace(" 're", "'re");
+            output = cleanup(output);
         }
 
         Ok(output)
@@ -167,7 +167,7 @@ impl PostProcessor for ByteLevel {
             }
         }
 
-        PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)
+        <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
     }
 }
 
@@ -41,7 +41,11 @@ impl PostProcessor for BertProcessing {
         add_special_tokens: bool,
     ) -> Result<Encoding> {
         if !add_special_tokens {
-            return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
+            return <dyn PostProcessor>::default_process(
+                encoding,
+                pair_encoding,
+                add_special_tokens,
+            );
         }
 
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
@@ -74,7 +74,11 @@ impl PostProcessor for RobertaProcessing {
         }
 
         if !add_special_tokens {
-            return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
+            return <dyn PostProcessor>::default_process(
+                encoding,
+                pair_encoding,
+                add_special_tokens,
+            );
         }
 
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
@@ -877,7 +877,7 @@ where
         let final_encoding = if let Some(processor) = &self.post_processor {
             processor.process(encoding, pair_encoding, add_special_tokens)?
         } else {
-            PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)?
+            <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
         };
 
         // 3. Then we pad if needed