Add CTC Decoder for Wav2Vec2 models (#693)

* Rust - add a CTCDecoder as a separate mod

* Adding bindings to Node + Python.

* Clippy update.

* Stub.

* Fixing roberta.json URLs.

* Moving test files to hf.co.

* Update cargo check and clippy to 1.52.

* Inner ':' is actually used for domains in sphinx.

Making `domain` work correctly was just too much work, so I took the easy
way out and used global roles for the custom rust extension.

* Update struct naming and docs

* Update changelog

Co-authored-by: Thomaub <github.thomaub@gmail.com>
Co-authored-by: Anthony MOI <m.anthony.moi@gmail.com>
Author: Nicolas Patry
Date: 2021-05-20 15:30:09 +02:00 (committed by GitHub)
Parent: e999a7b5f9
Commit: 2e2e7558f7
24 changed files with 353 additions and 50 deletions

View File

@ -27,7 +27,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
$(DATA_DIR)/roberta.json :
$(dir_guard)
wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@
$(DATA_DIR)/tokenizer-wiki.json :
$(dir_guard)

View File

@ -36,3 +36,16 @@ export function metaspaceDecoder(replacement?: string, addPrefixSpace?: boolean)
* This suffix will be replaced by whitespaces during the decoding
*/
export function bpeDecoder(suffix?: string): Decoder;
/**
* Instantiate a new CTC Decoder
* @param [pad_token='<pad>'] The pad token used by CTC to delimit a new token.
* @param [word_delimiter_token='|'] The word delimiter token. It will be replaced by a space.
* @param [cleanup=true] Whether to clean up some tokenization artifacts.
* Mainly spaces before punctuation and some abbreviated English forms.
*/
export function ctcDecoder(
pad_token?: string,
word_delimiter_token?: string,
cleanup?: boolean
): Decoder;

View File

@ -5,4 +5,5 @@ module.exports = {
wordPieceDecoder: native.decoders_WordPiece,
metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC,
};

View File

@ -1,4 +1,4 @@
import { bpeDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
import { bpeDecoder, ctcDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
describe("wordPieceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
@ -31,3 +31,14 @@ describe("bpeDecoder", () => {
expect(bpeDecoder(undefined)).toBeDefined();
});
});
describe("ctcDecoder", () => {
it("accepts `undefined` as parameter", () => {
expect(ctcDecoder(undefined)).toBeDefined();
});
it("encodes correctly", () => {
expect(
ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
).toEqual("hello");
});
});

View File

@ -97,11 +97,30 @@ fn bpe_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}
/// ctc_decoder(pad_token: String = "<pad>", word_delimiter_token: String = "|", cleanup = true)
fn ctc_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let pad_token = cx
.extract_opt::<String>(0)?
.unwrap_or_else(|| String::from("<pad>"));
let word_delimiter_token = cx
.extract_opt::<String>(1)?
.unwrap_or_else(|| String::from("|"));
let cleanup = cx.extract_opt::<bool>(2)?.unwrap_or(true);
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
tk::decoders::ctc::CTC::new(pad_token, word_delimiter_token, cleanup).into(),
));
Ok(decoder)
}
/// Register everything here
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
Ok(())
}
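When no arguments are passed, the Node binding above falls back to the same values as the core crate's `CTC::default()` further down in this commit. A small sanity-check sketch against the Rust crate (the `tokenizers` crate is what the bindings import as `tk`; the public fields are the ones declared in the new `ctc.rs`):

use tokenizers::decoders::ctc::CTC;

fn main() {
    // Explicit construction with the binding's fallback values...
    let explicit = CTC::new("<pad>".to_string(), "|".to_string(), true);
    // ...matches the crate's Default implementation.
    let default = CTC::default();
    assert_eq!(explicit.pad_token, default.pad_token);
    assert_eq!(explicit.word_delimiter_token, default.word_delimiter_token);
    assert_eq!(explicit.cleanup, default.cleanup);
}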

View File

@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- [#693]: Add a CTC Decoder for Wav2Vec2 models
## [0.10.2]
### Fixed
@ -309,6 +314,7 @@ delimiter (Works like `.split(delimiter)`)
- Fix a bug that was causing crashes in Python 3.5
[#694]: https://github.com/huggingface/tokenizers/pull/694
[#674]: https://github.com/huggingface/tokenizers/pull/674
[#656]: https://github.com/huggingface/tokenizers/pull/656
[#652]: https://github.com/huggingface/tokenizers/pull/652

View File

@ -30,4 +30,4 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
$(DATA_DIR)/roberta.json :
$(dir_guard)
wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@

View File

@ -5,3 +5,4 @@ ByteLevel = decoders.ByteLevel
WordPiece = decoders.WordPiece
Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC

View File

@ -68,6 +68,35 @@ class ByteLevel(Decoder):
"""
pass
class CTC(Decoder):
"""
CTC Decoder
Args:
pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
The pad token used by CTC to delimit a new token.
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
The word delimiter token. It will be replaced by a <space>.
cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to clean up some tokenization artifacts,
mainly spaces before punctuation and some abbreviated English forms.
"""
def __init__(self, pad_token="<pad>", word_delimiter_token="|", cleanup=True):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class Metaspace(Decoder):
"""
Metaspace Decoder

View File

@ -8,6 +8,7 @@ use serde::de::Error;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC;
use tk::decoders::metaspace::Metaspace;
use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper;
@ -43,6 +44,7 @@ impl PyDecoder {
DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
},
})
}
@ -264,6 +266,65 @@ impl PyBPEDecoder {
}
}
/// CTC Decoder
///
/// Args:
/// pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
/// The pad token used by CTC to delimit a new token.
/// word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
/// The word delimiter token. It will be replaced by a <space>.
/// cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
/// Whether to clean up some tokenization artifacts,
/// mainly spaces before punctuation and some abbreviated English forms.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=CTC)]
#[text_signature = "(self, pad_token=\"<pad>\", word_delimiter_token=\"|\", cleanup=True)"]
pub struct PyCTCDecoder {}
#[pymethods]
impl PyCTCDecoder {
#[getter]
fn get_pad_token(self_: PyRef<Self>) -> String {
getter!(self_, CTC, pad_token.clone())
}
#[setter]
fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
setter!(self_, CTC, pad_token, pad_token);
}
#[getter]
fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
getter!(self_, CTC, word_delimiter_token.clone())
}
#[setter]
fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
}
#[getter]
fn get_cleanup(self_: PyRef<Self>) -> bool {
getter!(self_, CTC, cleanup)
}
#[setter]
fn set_cleanup(self_: PyRef<Self>, cleanup: bool) {
setter!(self_, CTC, cleanup, cleanup);
}
#[new]
#[args(
pad_token = "String::from(\"<pad>\")",
word_delimiter_token = "String::from(\"|\")",
cleanup = "true"
)]
fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> (Self, PyDecoder) {
(
PyCTCDecoder {},
CTC::new(pad_token, word_delimiter_token, cleanup).into(),
)
}
}
#[derive(Clone)]
pub(crate) struct CustomDecoder {
inner: PyObject,

View File

@ -87,6 +87,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<decoders::PyWordPieceDec>()?;
m.add_class::<decoders::PyMetaspaceDec>()?;
m.add_class::<decoders::PyBPEDecoder>()?;
m.add_class::<decoders::PyCTCDecoder>()?;
Ok(())
}

View File

@ -15,7 +15,7 @@ where
I: Iterator,
{
pub fn new(iter: I, length: Option<usize>) -> Self {
Self { iter, length }
Self { length, iter }
}
}

View File

@ -2,7 +2,7 @@ import pytest
import pickle
import json
from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder
from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC
class TestByteLevel:
@ -108,3 +108,45 @@ class TestBPEDecoder:
# Modify these
decoder.suffix = "</w>"
assert decoder.suffix == "</w>"
class TestCTCDecoder:
def test_instantiate(self):
assert CTC() is not None
assert CTC(pad_token="[PAD]") is not None
assert isinstance(CTC(), Decoder)
assert isinstance(CTC(), CTC)
assert isinstance(pickle.loads(pickle.dumps(CTC())), CTC)
def test_decoding(self):
decoder = CTC()
assert (
decoder.decode(
["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
)
== "hello"
)
decoder = CTC(pad_token="[PAD]")
assert (
decoder.decode(
["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
)
== "hello"
)
def test_can_modify(self):
decoder = CTC(pad_token="[PAD]")
assert decoder.pad_token == "[PAD]"
assert decoder.word_delimiter_token == "|"
assert decoder.cleanup == True
# Modify these
decoder.pad_token = "{pad}"
assert decoder.pad_token == "{pad}"
decoder.word_delimiter_token = "_"
assert decoder.word_delimiter_token == "_"
decoder.cleanup = False
assert decoder.cleanup == False

View File

@ -10,7 +10,7 @@ logger = sphinx.util.logging.getLogger(__name__)
class RustRef:
def __call__(self, name, rawtext, text, lineno, inliner, options={}, content=[]):
doctype = name.split(":")[1]
doctype = name.split("_")[1]
parts = text.split("::")
if text.startswith("~"):
@ -87,10 +87,10 @@ class RustRef:
def setup(app):
app.add_role("rust:struct", RustRef())
app.add_role("rust:func", RustRef())
app.add_role("rust:meth", RustRef())
app.add_role("rust:trait", RustRef())
app.add_role("rust_struct", RustRef())
app.add_role("rust_func", RustRef())
app.add_role("rust_meth", RustRef())
app.add_role("rust_trait", RustRef())
return {
"version": "0.1",

View File

@ -58,47 +58,47 @@
classmethod
static method
Tokenizer
:rust:struct:`~tokenizers::tokenizer::Tokenizer`
:rust_struct:`~tokenizers::tokenizer::Tokenizer`
Tokenizer.train
:rust:meth:`~tokenizers::tokenizer::Tokenizer::train`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::train`
Tokenizer.save
:rust:meth:`~tokenizers::tokenizer::Tokenizer::save`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::save`
Tokenizer.from_file
:rust:meth:`~tokenizers::tokenizer::Tokenizer::from_file`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::from_file`
Tokenizer.encode
:rust:meth:`~tokenizers::tokenizer::Tokenizer::encode`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::encode`
Tokenizer.encode_batch
:rust:meth:`~tokenizers::tokenizer::Tokenizer::encode_batch`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::encode_batch`
Tokenizer.decode
:rust:meth:`~tokenizers::tokenizer::Tokenizer::decode`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::decode`
Tokenizer.decode_batch
:rust:meth:`~tokenizers::tokenizer::Tokenizer::decode_batch`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::decode_batch`
Tokenizer.token_to_id
:rust:meth:`~tokenizers::tokenizer::Tokenizer::token_to_id`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::token_to_id`
Tokenizer.enable_padding
:rust:meth:`~tokenizers::tokenizer::Tokenizer::enable_padding`
:rust_meth:`~tokenizers::tokenizer::Tokenizer::enable_padding`
Encoding
:rust:struct:`~tokenizers::tokenizer::Encoding`
:rust_struct:`~tokenizers::tokenizer::Encoding`
TemplateProcessing
:rust:struct:`~tokenizers::processors::template::TemplateProcessing`
:rust_struct:`~tokenizers::processors::template::TemplateProcessing`
Normalizer
:rust:trait:`~tokenizers::tokenizer::Normalizer`
:rust_trait:`~tokenizers::tokenizer::Normalizer`
normalizers.Sequence
:rust:struct:`~tokenizers::normalizers::utils::Sequence`
:rust_struct:`~tokenizers::normalizers::utils::Sequence`
pre_tokenizers.Whitespace
:rust:struct:`~tokenizers::normalizers::whitespace::Whitespace`
:rust_struct:`~tokenizers::normalizers::whitespace::Whitespace`
PreTokenizer
:rust:trait:`~tokenizers::tokenizer::PreTokenizer`
:rust_trait:`~tokenizers::tokenizer::PreTokenizer`
models.BPE
:rust:struct:`~tokenizers::models::bpe::BPE`
:rust_struct:`~tokenizers::models::bpe::BPE`
models.Unigram
:rust:struct:`~tokenizers::models::unigram::Unigram`
:rust_struct:`~tokenizers::models::unigram::Unigram`
models.WordLevel
:rust:struct:`~tokenizers::models::wordlevel::WordLevel`
:rust_struct:`~tokenizers::models::wordlevel::WordLevel`
models.WordPiece
:rust:struct:`~tokenizers::models::wordpiece::WordPiece`
:rust_struct:`~tokenizers::models::wordpiece::WordPiece`
Decoder
:rust:trait:`~tokenizers::tokenizer::Decoder`
:rust_trait:`~tokenizers::tokenizer::Decoder`
.. entities:: node

View File

@ -44,7 +44,7 @@ Training the tokenizer
.. entities:: rust
BpeTrainer
:rust:struct:`~tokenizers::models::bpe::BpeTrainer`
:rust_struct:`~tokenizers::models::bpe::BpeTrainer`
vocab_size
:obj:`vocab_size`
min_frequency

View File

@ -55,7 +55,7 @@ $(DATA_DIR)/bert-% :
$(DATA_DIR)/unigram% :
$(dir_guard)
wget https://storage.googleapis.com/tokenizers/unigram$* -O $@
wget https://huggingface.co/Narsil/small/raw/main/unigram$* -O $@
$(DATA_DIR)/albert-base-v1-tokenizer.json :
$(dir_guard)
@ -70,7 +70,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
$(DATA_DIR)/roberta.json :
$(dir_guard)
wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
wget https://huggingface.co/Narsil/small/raw/main/roberta.json -O $@
$(DATA_DIR)/tokenizer-wiki.json :
$(dir_guard)

View File

@ -0,0 +1,103 @@
use crate::decoders::wordpiece;
use crate::tokenizer::{Decoder, Result};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
/// The CTC (Connectionist Temporal Classification) decoder takes care
/// of sanitizing a list of input tokens.
/// Due to alignment issues, the output of some models can come
/// with duplicated tokens.
#[serde(tag = "type")]
#[non_exhaustive]
pub struct CTC {
/// The pad token used by CTC to delimit a new token.
pub pad_token: String,
/// The word delimiter token. It will be replaced by a <space>.
pub word_delimiter_token: String,
/// Whether to clean up some tokenization artifacts,
/// mainly spaces before punctuation and some abbreviated English forms.
pub cleanup: bool,
}
impl CTC {
pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self {
Self {
pad_token,
word_delimiter_token,
cleanup,
}
}
}
impl Default for CTC {
fn default() -> Self {
Self {
pad_token: "<pad>".to_string(),
word_delimiter_token: "|".to_string(),
cleanup: true,
}
}
}
impl Decoder for CTC {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
let mut output = tokens
.into_iter()
.dedup()
.join("")
.replace(&self.pad_token, "");
if self.cleanup {
output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
}
Ok(output)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn handmade_sample() {
let ctc_decoder = CTC::default();
let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad>"
.split(' ')
.map(|s| s.to_string())
.collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
"hello".to_string()
);
}
#[test]
fn handmade_with_delimiter_sample() {
let ctc_decoder = CTC::default();
let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad> <pad> | <pad> w o o o r <pad> <pad> l l d <pad> <pad> <pad> <pad>"
.split(' ')
.map(|s| s.to_string())
.collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
"hello world".to_string()
);
}
#[test]
fn librispeech_sample() {
let ctc_decoder = CTC::default();
let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
"A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
);
}
#[test]
fn another_librispeech_sample() {
let ctc_decoder = CTC::default();
let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
assert_eq!(
ctc_decoder.decode(id_to_string_result).unwrap(),
"HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
);
}
}
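To make the decode path above easier to follow: the decoder collapses consecutive duplicate tokens, concatenates what remains, strips the pad token, and (when `cleanup` is enabled) runs the shared WordPiece cleanup before turning the word delimiter into a space. A standalone sketch of the same steps using only the standard library (the real implementation relies on `itertools::dedup` and `wordpiece::cleanup` instead):

fn ctc_collapse(tokens: &[&str], pad: &str, delim: &str) -> String {
    let mut out = String::new();
    let mut prev: Option<&str> = None;
    for &tok in tokens {
        // 1. Drop consecutive duplicates: CTC emits one label per frame,
        //    so repeated frames produce repeated tokens.
        if prev == Some(tok) {
            continue;
        }
        prev = Some(tok);
        // 2. Remove the pad token entirely.
        if tok == pad {
            continue;
        }
        out.push_str(tok);
    }
    // 3. The word delimiter becomes a plain space; in the real decoder this
    //    (together with wordpiece::cleanup) only happens when `cleanup` is true.
    out.replace(delim, " ")
}

fn main() {
    let frames = ["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "o", "|", "h", "i"];
    assert_eq!(ctc_collapse(&frames, "<pad>", "|"), "hello hi");
}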

View File

@ -1,4 +1,5 @@
pub mod bpe;
pub mod ctc;
pub mod wordpiece;
// Re-export these as decoders
@ -8,6 +9,7 @@ pub use super::pre_tokenizers::metaspace;
use serde::{Deserialize, Serialize};
use crate::decoders::bpe::BPEDecoder;
use crate::decoders::ctc::CTC;
use crate::decoders::wordpiece::WordPiece;
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::pre_tokenizers::metaspace::Metaspace;
@ -20,6 +22,7 @@ pub enum DecoderWrapper {
ByteLevel(ByteLevel),
WordPiece(WordPiece),
Metaspace(Metaspace),
CTC(CTC),
}
impl Decoder for DecoderWrapper {
@ -29,6 +32,7 @@ impl Decoder for DecoderWrapper {
DecoderWrapper::ByteLevel(bl) => bl.decode(tokens),
DecoderWrapper::Metaspace(ms) => ms.decode(tokens),
DecoderWrapper::WordPiece(wp) => wp.decode(tokens),
DecoderWrapper::CTC(ctc) => ctc.decode(tokens),
}
}
}
@ -37,3 +41,4 @@ impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
impl_enum_from!(CTC, DecoderWrapper, CTC);
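The `impl_enum_from!` line is what lets the binding code above write `CTC::new(...).into()` and get a `DecoderWrapper`. Judging from its use, the macro presumably expands to an ordinary `From` impl along these lines (a hedged sketch, not the actual macro output):

// Hypothetical expansion of `impl_enum_from!(CTC, DecoderWrapper, CTC)`,
// inferred from how `.into()` is used at the binding call sites.
impl From<CTC> for DecoderWrapper {
    fn from(decoder: CTC) -> Self {
        DecoderWrapper::CTC(decoder)
    }
}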

View File

@ -28,12 +28,8 @@ impl Default for WordPiece {
}
}
}
impl Decoder for WordPiece {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
if self.cleanup {
output = output
pub fn cleanup(dirty_input: String) -> String {
dirty_input
.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
@ -44,7 +40,14 @@ impl Decoder for WordPiece {
.replace(" do not", " don't")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re");
.replace(" 're", "'re")
}
impl Decoder for WordPiece {
fn decode(&self, tokens: Vec<String>) -> Result<String> {
let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
if self.cleanup {
output = cleanup(output);
}
Ok(output)
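With `cleanup` extracted into a public free function, the CTC decoder can reuse the exact same artifact fixes as WordPiece. A usage sketch exercising only the replacements visible in the hunk above (it assumes the helper stays exported as `tokenizers::decoders::wordpiece::cleanup`):

use tokenizers::decoders::wordpiece::cleanup;

fn main() {
    // Spaces before punctuation and a few abbreviated English forms
    // are collapsed by the shared helper.
    let dirty = "what ? it 's fine . we 've arrived !".to_string();
    assert_eq!(cleanup(dirty), "what? it's fine. we've arrived!");
}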

View File

@ -167,7 +167,7 @@ impl PostProcessor for ByteLevel {
}
}
PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
}
}
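The `<dyn PostProcessor>::default_process` spelling here (and in the two processors below) is the fully qualified way to call a helper defined for the trait-object type, which the 1.52 toolchain bump presumably started flagging in its bare `PostProcessor::...` form. A minimal sketch of the pattern, assuming `default_process` is an inherent function on the trait-object type as the call sites suggest:

trait PostProcess {
    fn process(&self, s: &str) -> String;
}

// Inherent helper defined on the trait-object type itself.
impl dyn PostProcess {
    fn default_process(s: &str) -> String {
        s.to_string()
    }
}

struct Upper;
impl PostProcess for Upper {
    fn process(&self, s: &str) -> String {
        s.to_uppercase()
    }
}

fn run(p: Option<&dyn PostProcess>, s: &str) -> String {
    match p {
        Some(p) => p.process(s),
        // `PostProcess::default_process(s)` would use the bare trait name as a
        // type, which newer toolchains flag; the `<dyn Trait>::fn` form is the
        // accepted spelling.
        None => <dyn PostProcess>::default_process(s),
    }
}

fn main() {
    assert_eq!(run(None, "abc"), "abc");
    assert_eq!(run(Some(&Upper), "abc"), "ABC");
}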

View File

@ -41,7 +41,11 @@ impl PostProcessor for BertProcessing {
add_special_tokens: bool,
) -> Result<Encoding> {
if !add_special_tokens {
return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
return <dyn PostProcessor>::default_process(
encoding,
pair_encoding,
add_special_tokens,
);
}
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();

View File

@ -74,7 +74,11 @@ impl PostProcessor for RobertaProcessing {
}
if !add_special_tokens {
return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
return <dyn PostProcessor>::default_process(
encoding,
pair_encoding,
add_special_tokens,
);
}
let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();

View File

@ -877,7 +877,7 @@ where
let final_encoding = if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding, add_special_tokens)?
} else {
PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)?
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
};
// 3. Then we pad if needed