diff --git a/bindings/node/Makefile b/bindings/node/Makefile index 3f09b78a..19404f99 100644 --- a/bindings/node/Makefile +++ b/bindings/node/Makefile @@ -27,7 +27,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt $(DATA_DIR)/roberta.json : $(dir_guard) - wget https://storage.googleapis.com/tokenizers/roberta.json -O $@ + wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@ $(DATA_DIR)/tokenizer-wiki.json : $(dir_guard) diff --git a/bindings/node/lib/bindings/decoders.d.ts b/bindings/node/lib/bindings/decoders.d.ts index fdbb4047..9841f8eb 100644 --- a/bindings/node/lib/bindings/decoders.d.ts +++ b/bindings/node/lib/bindings/decoders.d.ts @@ -36,3 +36,16 @@ export function metaspaceDecoder(replacement?: string, addPrefixSpace?: boolean) * This suffix will be replaced by whitespaces during the decoding */ export function bpeDecoder(suffix?: string): Decoder; + +/** + * Instantiate a new CTC Decoder + * @param [pad_token='<pad>'] The pad token used by CTC to delimit a new token. + * @param [word_delimiter_token='|'] The word delimiter token. It will be replaced by a space + * @param [cleanup=true] Whether to clean up some tokenization artifacts. + * Mainly spaces before punctuation, and some abbreviated English forms. + */ +export function ctcDecoder( + pad_token?: string, + word_delimiter_token?: string, + cleanup?: boolean +): Decoder; diff --git a/bindings/node/lib/bindings/decoders.js b/bindings/node/lib/bindings/decoders.js index b7eed529..8789a448 100644 --- a/bindings/node/lib/bindings/decoders.js +++ b/bindings/node/lib/bindings/decoders.js @@ -5,4 +5,5 @@ module.exports = { wordPieceDecoder: native.decoders_WordPiece, metaspaceDecoder: native.decoders_Metaspace, bpeDecoder: native.decoders_BPEDecoder, + ctcDecoder: native.decoders_CTC, }; diff --git a/bindings/node/lib/bindings/decoders.test.ts b/bindings/node/lib/bindings/decoders.test.ts index e6aa0e67..b23f243f 100644 --- a/bindings/node/lib/bindings/decoders.test.ts +++ b/bindings/node/lib/bindings/decoders.test.ts @@ -1,4 +1,4 @@ -import { bpeDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders"; +import { bpeDecoder, ctcDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders"; describe("wordPieceDecoder", () => { it("accepts `undefined` as first parameter", () => { @@ -31,3 +31,14 @@ describe("bpeDecoder", () => { expect(bpeDecoder(undefined)).toBeDefined(); }); }); + +describe("ctcDecoder", () => { + it("accepts `undefined` as parameter", () => { + expect(ctcDecoder(undefined)).toBeDefined(); + }); + it("decodes correctly", () => { + expect( + ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"]) + ).toEqual("hello"); + }); +}); diff --git a/bindings/node/native/src/decoders.rs b/bindings/node/native/src/decoders.rs index 88a3bbbe..9a01bc4d 100644 --- a/bindings/node/native/src/decoders.rs +++ b/bindings/node/native/src/decoders.rs @@ -97,11 +97,30 @@ fn bpe_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> { Ok(decoder) } +/// ctc_decoder(pad_token: String = "<pad>", word_delimiter_token: String = "|", cleanup = true) +fn ctc_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> { + let pad_token = cx + .extract_opt::<String>(0)? + .unwrap_or_else(|| String::from("<pad>")); + let word_delimiter_token = cx + .extract_opt::<String>(1)?
+ .unwrap_or_else(|| String::from("|")); + let cleanup = cx.extract_opt::<bool>(2)?.unwrap_or(true); + + let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?; + let guard = cx.lock(); + decoder.borrow_mut(&guard).decoder = Some(Arc::new( + tk::decoders::ctc::CTC::new(pad_token, word_delimiter_token, cleanup).into(), + )); + Ok(decoder) +} + /// Register everything here pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> { m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?; m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?; m.export_function(&format!("{}_Metaspace", prefix), metaspace)?; m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?; + m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?; Ok(()) } diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index 2741385b..4fea9a83 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- [#693]: Add a CTC Decoder for Wav2Vec2 models + ## [0.10.2] ### Fixed @@ -309,6 +314,7 @@ delimiter (Works like `.split(delimiter)`) - Fix a bug that was causing crashes in Python 3.5 +[#693]: https://github.com/huggingface/tokenizers/pull/693 [#674]: https://github.com/huggingface/tokenizers/pull/674 [#656]: https://github.com/huggingface/tokenizers/pull/656 [#652]: https://github.com/huggingface/tokenizers/pull/652 diff --git a/bindings/python/Makefile b/bindings/python/Makefile index 7ad70943..a3e262d6 100644 --- a/bindings/python/Makefile +++ b/bindings/python/Makefile @@ -30,4 +30,4 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt $(DATA_DIR)/roberta.json : $(dir_guard) - wget https://storage.googleapis.com/tokenizers/roberta.json -O $@ + wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@ diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.py b/bindings/python/py_src/tokenizers/decoders/__init__.py index edb99cfe..ce1af33f 100644 --- a/bindings/python/py_src/tokenizers/decoders/__init__.py +++ b/bindings/python/py_src/tokenizers/decoders/__init__.py @@ -5,3 +5,4 @@ ByteLevel = decoders.ByteLevel WordPiece = decoders.WordPiece Metaspace = decoders.Metaspace BPEDecoder = decoders.BPEDecoder +CTC = decoders.CTC diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi index a6886742..832c9e71 100644 --- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi +++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi @@ -68,6 +68,35 @@ class ByteLevel(Decoder): """ pass +class CTC(Decoder): + """ + CTC Decoder + + Args: + pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`): + The pad token used by CTC to delimit a new token. + word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`): + The word delimiter token. It will be replaced by a space + cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to clean up some tokenization artifacts. + Mainly spaces before punctuation, and some abbreviated English forms.
+ """ + + def __init__(self, pad_token="", word_delimiter_token="|", cleanup=True): + pass + def decode(self, tokens): + """ + Decode the given list of tokens to a final string + + Args: + tokens (:obj:`List[str]`): + The list of tokens to decode + + Returns: + :obj:`str`: The decoded string + """ + pass + class Metaspace(Decoder): """ Metaspace Decoder diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs index f27ed5e1..5f15838c 100644 --- a/bindings/python/src/decoders.rs +++ b/bindings/python/src/decoders.rs @@ -8,6 +8,7 @@ use serde::de::Error; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::decoders::bpe::BPEDecoder; use tk::decoders::byte_level::ByteLevel; +use tk::decoders::ctc::CTC; use tk::decoders::metaspace::Metaspace; use tk::decoders::wordpiece::WordPiece; use tk::decoders::DecoderWrapper; @@ -43,6 +44,7 @@ impl PyDecoder { DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py), DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py), DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py), + DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py), }, }) } @@ -264,6 +266,65 @@ impl PyBPEDecoder { } } +/// CTC Decoder +/// +/// Args: +/// pad_token (:obj:`str`, `optional`, defaults to :obj:``): +/// The pad token used by CTC to delimit a new token. +/// word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`): +/// The word delimiter token. It will be replaced by a +/// cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`): +/// Whether to cleanup some tokenization artifacts. +/// Mainly spaces before punctuation, and some abbreviated english forms. +#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=CTC)] +#[text_signature = "(self, pad_token=\"\", word_delimiter_token=\"|\", cleanup=True)"] +pub struct PyCTCDecoder {} +#[pymethods] +impl PyCTCDecoder { + #[getter] + fn get_pad_token(self_: PyRef) -> String { + getter!(self_, CTC, pad_token.clone()) + } + + #[setter] + fn set_pad_token(self_: PyRef, pad_token: String) { + setter!(self_, CTC, pad_token, pad_token); + } + + #[getter] + fn get_word_delimiter_token(self_: PyRef) -> String { + getter!(self_, CTC, word_delimiter_token.clone()) + } + + #[setter] + fn set_word_delimiter_token(self_: PyRef, word_delimiter_token: String) { + setter!(self_, CTC, word_delimiter_token, word_delimiter_token); + } + + #[getter] + fn get_cleanup(self_: PyRef) -> bool { + getter!(self_, CTC, cleanup) + } + + #[setter] + fn set_cleanup(self_: PyRef, cleanup: bool) { + setter!(self_, CTC, cleanup, cleanup); + } + + #[new] + #[args( + pad_token = "String::from(\"\")", + word_delimiter_token = "String::from(\"|\")", + cleanup = "true" + )] + fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> (Self, PyDecoder) { + ( + PyCTCDecoder {}, + CTC::new(pad_token, word_delimiter_token, cleanup).into(), + ) + } +} + #[derive(Clone)] pub(crate) struct CustomDecoder { inner: PyObject, diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 2301a850..5bae12ba 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -87,6 +87,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/bindings/python/src/utils/iterators.rs b/bindings/python/src/utils/iterators.rs index 6b5a4175..0715df51 100644 --- 
a/bindings/python/src/utils/iterators.rs +++ b/bindings/python/src/utils/iterators.rs @@ -15,7 +15,7 @@ where I: Iterator, { pub fn new(iter: I, length: Option<usize>) -> Self { - Self { iter, length } + Self { length, iter } } } diff --git a/bindings/python/tests/bindings/test_decoders.py b/bindings/python/tests/bindings/test_decoders.py index ab427bf7..41e7187e 100644 --- a/bindings/python/tests/bindings/test_decoders.py +++ b/bindings/python/tests/bindings/test_decoders.py @@ -2,7 +2,7 @@ import pytest import pickle import json -from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder +from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC class TestByteLevel: @@ -108,3 +108,45 @@ class TestBPEDecoder: # Modify these decoder.suffix = "</w>" assert decoder.suffix == "</w>" + + +class TestCTCDecoder: + def test_instantiate(self): + assert CTC() is not None + assert CTC(pad_token="[PAD]") is not None + assert isinstance(CTC(), Decoder) + assert isinstance(CTC(), CTC) + assert isinstance(pickle.loads(pickle.dumps(CTC())), CTC) + + def test_decoding(self): + decoder = CTC() + assert ( + decoder.decode( + ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"] + ) + == "hello" + ) + decoder = CTC(pad_token="[PAD]") + assert ( + decoder.decode( + ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"] + ) + == "hello" + ) + + def test_can_modify(self): + decoder = CTC(pad_token="[PAD]") + + assert decoder.pad_token == "[PAD]" + assert decoder.word_delimiter_token == "|" + assert decoder.cleanup == True + + # Modify these + decoder.pad_token = "{pad}" + assert decoder.pad_token == "{pad}" + + decoder.word_delimiter_token = "_" + assert decoder.word_delimiter_token == "_" + + decoder.cleanup = False + assert decoder.cleanup == False diff --git a/docs/source/_ext/rust_doc.py b/docs/source/_ext/rust_doc.py index ca549d49..47098341 100644 --- a/docs/source/_ext/rust_doc.py +++ b/docs/source/_ext/rust_doc.py @@ -10,7 +10,7 @@ logger = sphinx.util.logging.getLogger(__name__) class RustRef: def __call__(self, name, rawtext, text, lineno, inliner, options={}, content=[]): - doctype = name.split(":")[1] + doctype = name.split("_")[1] parts = text.split("::") if text.startswith("~"): @@ -87,10 +87,10 @@ class RustRef: def setup(app): - app.add_role("rust:struct", RustRef()) - app.add_role("rust:func", RustRef()) - app.add_role("rust:meth", RustRef()) - app.add_role("rust:trait", RustRef()) + app.add_role("rust_struct", RustRef()) + app.add_role("rust_func", RustRef()) + app.add_role("rust_meth", RustRef()) + app.add_role("rust_trait", RustRef()) return { "version": "0.1", diff --git a/docs/source/entities.inc b/docs/source/entities.inc index cd3f5048..d8004abf 100644 --- a/docs/source/entities.inc +++ b/docs/source/entities.inc @@ -58,47 +58,47 @@ classmethod static method Tokenizer - :rust:struct:`~tokenizers::tokenizer::Tokenizer` + :rust_struct:`~tokenizers::tokenizer::Tokenizer` Tokenizer.train - :rust:meth:`~tokenizers::tokenizer::Tokenizer::train` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::train` Tokenizer.save - :rust:meth:`~tokenizers::tokenizer::Tokenizer::save` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::save` Tokenizer.from_file - :rust:meth:`~tokenizers::tokenizer::Tokenizer::from_file` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::from_file` Tokenizer.encode - :rust:meth:`~tokenizers::tokenizer::Tokenizer::encode` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::encode` Tokenizer.encode_batch - 
:rust:meth:`~tokenizers::tokenizer::Tokenizer::encode_batch` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::encode_batch` Tokenizer.decode - :rust:meth:`~tokenizers::tokenizer::Tokenizer::decode` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::decode` Tokenizer.decode_batch - :rust:meth:`~tokenizers::tokenizer::Tokenizer::decode_batch` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::decode_batch` Tokenizer.token_to_id - :rust:meth:`~tokenizers::tokenizer::Tokenizer::token_to_id` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::token_to_id` Tokenizer.enable_padding - :rust:meth:`~tokenizers::tokenizer::Tokenizer::enable_padding` + :rust_meth:`~tokenizers::tokenizer::Tokenizer::enable_padding` Encoding - :rust:struct:`~tokenizers::tokenizer::Encoding` + :rust_struct:`~tokenizers::tokenizer::Encoding` TemplateProcessing - :rust:struct:`~tokenizers::processors::template::TemplateProcessing` + :rust_struct:`~tokenizers::processors::template::TemplateProcessing` Normalizer - :rust:trait:`~tokenizers::tokenizer::Normalizer` + :rust_trait:`~tokenizers::tokenizer::Normalizer` normalizers.Sequence - :rust:struct:`~tokenizers::normalizers::utils::Sequence` + :rust_struct:`~tokenizers::normalizers::utils::Sequence` pre_tokenizers.Whitespace - :rust:struct:`~tokenizers::normalizers::whitespace::Whitespace` + :rust_struct:`~tokenizers::normalizers::whitespace::Whitespace` PreTokenizer - :rust:trait:`~tokenizers::tokenizer::PreTokenizer` + :rust_trait:`~tokenizers::tokenizer::PreTokenizer` models.BPE - :rust:struct:`~tokenizers::models::bpe::BPE` + :rust_struct:`~tokenizers::models::bpe::BPE` models.Unigram - :rust:struct:`~tokenizers::models::unigram::Unigram` + :rust_struct:`~tokenizers::models::unigram::Unigram` models.WordLevel - :rust:struct:`~tokenizers::models::wordlevel::WordLevel` + :rust_struct:`~tokenizers::models::wordlevel::WordLevel` models.WordPiece - :rust:struct:`~tokenizers::models::wordpiece::WordPiece` + :rust_struct:`~tokenizers::models::wordpiece::WordPiece` Decoder - :rust:trait:`~tokenizers::tokenizer::Decoder` + :rust_trait:`~tokenizers::tokenizer::Decoder` .. entities:: node diff --git a/docs/source/quicktour.rst b/docs/source/quicktour.rst index 309d83de..a8ad2600 100644 --- a/docs/source/quicktour.rst +++ b/docs/source/quicktour.rst @@ -44,7 +44,7 @@ Training the tokenizer .. 
entities:: rust BpeTrainer - :rust:struct:`~tokenizers::models::bpe::BpeTrainer` + :rust_struct:`~tokenizers::models::bpe::BpeTrainer` vocab_size :obj:`vocab_size` min_frequency diff --git a/tokenizers/Makefile b/tokenizers/Makefile index 29173d75..486f5a56 100644 --- a/tokenizers/Makefile +++ b/tokenizers/Makefile @@ -55,7 +55,7 @@ $(DATA_DIR)/bert-% : $(DATA_DIR)/unigram% : $(dir_guard) - wget https://storage.googleapis.com/tokenizers/unigram$* -O $@ + wget https://huggingface.co/Narsil/small/raw/main/unigram$* -O $@ $(DATA_DIR)/albert-base-v1-tokenizer.json : $(dir_guard) @@ -70,7 +70,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt $(DATA_DIR)/roberta.json : $(dir_guard) - wget https://storage.googleapis.com/tokenizers/roberta.json -O $@ + wget https://huggingface.co/Narsil/small/raw/main/roberta.json -O $@ $(DATA_DIR)/tokenizer-wiki.json : $(dir_guard) diff --git a/tokenizers/src/decoders/ctc.rs b/tokenizers/src/decoders/ctc.rs new file mode 100644 index 00000000..17f7ba16 --- /dev/null +++ b/tokenizers/src/decoders/ctc.rs @@ -0,0 +1,103 @@ +use crate::decoders::wordpiece; +use crate::tokenizer::{Decoder, Result}; + +use itertools::Itertools; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +/// The CTC (Connectionist Temporal Classification) decoder takes care +/// of sanitizing a list of input tokens. +/// Due to alignment problems, the output of some models can come +/// with duplicated tokens. +#[serde(tag = "type")] +#[non_exhaustive] +pub struct CTC { + /// The pad token used by CTC to delimit a new token. + pub pad_token: String, + /// The word delimiter token. It will be replaced by a space + pub word_delimiter_token: String, + /// Whether to clean up some tokenization artifacts. + /// Mainly spaces before punctuation, and some abbreviated English forms.
+ pub cleanup: bool, +} + +impl CTC { + pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self { + Self { + pad_token, + word_delimiter_token, + cleanup, + } + } +} + +impl Default for CTC { + fn default() -> Self { + Self { + pad_token: "<pad>".to_string(), + word_delimiter_token: "|".to_string(), + cleanup: true, + } + } +} + +impl Decoder for CTC { + fn decode(&self, tokens: Vec<String>) -> Result<String> { + let mut output = tokens + .into_iter() + .dedup() + .join("") + .replace(&self.pad_token, ""); + if self.cleanup { + output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " "); + } + Ok(output) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn handmade_sample() { + let ctc_decoder = CTC::default(); + let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad>" + .split(' ') + .map(|s| s.to_string()) + .collect(); + assert_eq!( + ctc_decoder.decode(id_to_string_result).unwrap(), + "hello".to_string() + ); + } + #[test] + fn handmade_with_delimiter_sample() { + let ctc_decoder = CTC::default(); + let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o | w o o o r l l d <pad>" + .split(' ') + .map(|s| s.to_string()) + .collect(); + assert_eq!( + ctc_decoder.decode(id_to_string_result).unwrap(), + "hello world".to_string() + ); + } + #[test] + fn librispeech_sample() { + let ctc_decoder = CTC::default(); + let id_to_string_result = " A | | M A N | | | S A I D D | | T T O | | T H E E | | | U U N N I V E R R S E E | | S S I R R | | | I | E X I S T | | ".split(' ').map(|s| s.to_string()).collect(); + assert_eq!( + ctc_decoder.decode(id_to_string_result).unwrap(), + "A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string() + ); + } + #[test] + fn another_librispeech_sample() { + let ctc_decoder = CTC::default(); + let id_to_string_result = " H I S S | | I N S T T A N C C T | | | | | P A N N N I C | | W A S | | F O L L <pad> L O O W E E D | | B Y | | | A | | S S S M M A L L <pad> L | | | S H H A R R P | B L L O W W | | | H I G H H | | O N | | H I S S | | C H H E S S T T | | | ".split(' ').map(|s| s.to_string()).collect(); + assert_eq!( + ctc_decoder.decode(id_to_string_result).unwrap(), + "HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string() + ); + } +} diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index 261852f4..9dac4274 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -1,4 +1,5 @@ pub mod bpe; +pub mod ctc; pub mod wordpiece; // Re-export these as decoders @@ -8,6 +9,7 @@ pub use super::pre_tokenizers::metaspace; use serde::{Deserialize, Serialize}; use crate::decoders::bpe::BPEDecoder; +use crate::decoders::ctc::CTC; use crate::decoders::wordpiece::WordPiece; use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::metaspace::Metaspace; @@ -20,6 +22,7 @@ pub enum DecoderWrapper { ByteLevel(ByteLevel), WordPiece(WordPiece), Metaspace(Metaspace), + CTC(CTC), } impl Decoder for DecoderWrapper { @@ -29,6 +32,7 @@ impl Decoder for DecoderWrapper { DecoderWrapper::ByteLevel(bl) => bl.decode(tokens), DecoderWrapper::Metaspace(ms) => ms.decode(tokens), DecoderWrapper::WordPiece(wp) => wp.decode(tokens), + DecoderWrapper::CTC(ctc) => ctc.decode(tokens), } } } @@ -37,3 +41,4 @@ impl_enum_from!(BPEDecoder, DecoderWrapper, BPE); impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel); impl_enum_from!(Metaspace, DecoderWrapper, Metaspace); impl_enum_from!(WordPiece, DecoderWrapper, WordPiece); +impl_enum_from!(CTC, DecoderWrapper, CTC); diff --git a/tokenizers/src/decoders/wordpiece.rs 
b/tokenizers/src/decoders/wordpiece.rs index 09e0b630..5f3b5258 100644 --- a/tokenizers/src/decoders/wordpiece.rs +++ b/tokenizers/src/decoders/wordpiece.rs @@ -28,23 +28,26 @@ impl Default for WordPiece { } } } +pub fn cleanup(dirty_input: String) -> String { + dirty_input + .replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" do not", " don't") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") +} impl Decoder for WordPiece { fn decode(&self, tokens: Vec<String>) -> Result<String> { let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), ""); if self.cleanup { - output = output - .replace(" .", ".") - .replace(" ?", "?") - .replace(" !", "!") - .replace(" ,", ",") - .replace(" ' ", "'") - .replace(" n't", "n't") - .replace(" 'm", "'m") - .replace(" do not", " don't") - .replace(" 's", "'s") - .replace(" 've", "'ve") - .replace(" 're", "'re"); + output = cleanup(output); } Ok(output) diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 3e456c5f..794f0d8a 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -167,7 +167,7 @@ impl PostProcessor for ByteLevel { } } - PostProcessor::default_process(encoding, pair_encoding, add_special_tokens) + <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens) } } diff --git a/tokenizers/src/processors/bert.rs b/tokenizers/src/processors/bert.rs index 170d8ad2..1027dd7d 100644 --- a/tokenizers/src/processors/bert.rs +++ b/tokenizers/src/processors/bert.rs @@ -41,7 +41,11 @@ impl PostProcessor for BertProcessing { add_special_tokens: bool, ) -> Result<Encoding> { if !add_special_tokens { - return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens); + return <dyn PostProcessor>::default_process( + encoding, + pair_encoding, + add_special_tokens, + ); } let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat(); diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs index 8f915ffc..0036a7e8 100644 --- a/tokenizers/src/processors/roberta.rs +++ b/tokenizers/src/processors/roberta.rs @@ -74,7 +74,11 @@ impl PostProcessor for RobertaProcessing { } if !add_special_tokens { - return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens); + return <dyn PostProcessor>::default_process( + encoding, + pair_encoding, + add_special_tokens, + ); } let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat(); diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index c42c992b..9d4cada4 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -877,7 +877,7 @@ where let final_encoding = if let Some(processor) = &self.post_processor { processor.process(encoding, pair_encoding, add_special_tokens)? } else { - PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)? + <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)? }; // 3. Then we pad if needed
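
For reference, a minimal Python sketch of how the new CTC decoder is meant to be used once the bindings above are built. The token sequence is illustrative only; the defaults assumed here (pad_token="<pad>", word_delimiter_token="|", cleanup=True) are the ones defined in tokenizers/src/decoders/ctc.rs above.

from tokenizers.decoders import CTC

# Default configuration: pad_token="<pad>", word_delimiter_token="|", cleanup=True
decoder = CTC()

# Consecutive duplicate tokens are collapsed, pad tokens are stripped, and the
# word delimiter token is replaced by a space.
frames = ["<pad>", "h", "h", "e", "l", "l", "<pad>", "l", "o", "|", "h", "i"]
print(decoder.decode(frames))  # -> "hello hi"

# The attributes exposed by PyCTCDecoder can be read and modified afterwards,
# as exercised in test_decoders.py above.
decoder.pad_token = "[PAD]"
decoder.cleanup = False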