Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Add CTC Decoder for Wave2Vec models (#693)
* Rust - add a CTCDecoder as a separate mod
* Adding bindings to Node + Python.
* Clippy update.
* Stub.
* Fixing roberta.json URLs.
* Moving test files to hf.co.
* Update cargo check and clippy to 1.52.
* Inner ':' actually is used for domains in sphinx. Making `domain` work correctly was just too much work, so I went the easy way and have global roles for the custom rust extension.
* Update struct naming and docs
* Update changelog

Co-authored-by: Thomaub <github.thomaub@gmail.com>
Co-authored-by: Anthony MOI <m.anthony.moi@gmail.com>
@@ -27,7 +27,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
 
 $(DATA_DIR)/roberta.json :
 	$(dir_guard)
-	wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
+	wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@
 
 $(DATA_DIR)/tokenizer-wiki.json :
 	$(dir_guard)
bindings/node/lib/bindings/decoders.d.ts (13 changes, vendored)
@@ -36,3 +36,16 @@ export function metaspaceDecoder(replacement?: string, addPrefixSpace?: boolean)
  * This suffix will be replaced by whitespaces during the decoding
  */
 export function bpeDecoder(suffix?: string): Decoder;
+
+/**
+ * Instantiate a new CTC Decoder
+ * @param [pad_token='<pad>'] The pad token used by CTC to delimit a new token.
+ * @param [word_delimiter_token='|'] The word delimiter token. It will be replaced by a space
+ * @param [cleanup=true] Whether to cleanup some tokenization artifacts.
+ * Mainly spaces before punctuation, and some abbreviated english forms.
+ */
+export function ctcDecoder(
+  pad_token?: string,
+  word_delimiter_token?: string,
+  cleanup?: boolean
+): Decoder;
@@ -5,4 +5,5 @@ module.exports = {
   wordPieceDecoder: native.decoders_WordPiece,
   metaspaceDecoder: native.decoders_Metaspace,
   bpeDecoder: native.decoders_BPEDecoder,
+  ctcDecoder: native.decoders_CTC,
 };
@@ -1,4 +1,4 @@
-import { bpeDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
+import { bpeDecoder, ctcDecoder, metaspaceDecoder, wordPieceDecoder } from "./decoders";
 
 describe("wordPieceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
@@ -31,3 +31,14 @@ describe("bpeDecoder", () => {
     expect(bpeDecoder(undefined)).toBeDefined();
   });
 });
+
+describe("ctcDecoder", () => {
+  it("accepts `undefined` as parameter", () => {
+    expect(ctcDecoder(undefined)).toBeDefined();
+  });
+  it("decodes correctly", () => {
+    expect(
+      ctcDecoder().decode(["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "l", "o"])
+    ).toEqual("hello");
+  });
+});
@@ -97,11 +97,30 @@ fn bpe_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// ctc_decoder(pad_token: String = "<pad>", word_delimiter_token: String = "|", cleanup = true)
+fn ctc_decoder(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let pad_token = cx
+        .extract_opt::<String>(0)?
+        .unwrap_or_else(|| String::from("<pad>"));
+    let word_delimiter_token = cx
+        .extract_opt::<String>(1)?
+        .unwrap_or_else(|| String::from("|"));
+    let cleanup = cx.extract_opt::<bool>(2)?.unwrap_or(true);
+
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::decoders::ctc::CTC::new(pad_token, word_delimiter_token, cleanup).into(),
+    ));
+    Ok(decoder)
+}
+
 /// Register everything here
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
     m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
+    m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
     Ok(())
 }
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [Unreleased]
+
+### Added
+- [#693]: Add a CTC Decoder for Wave2Vec models
+
 ## [0.10.2]
 
 ### Fixed
@@ -309,6 +314,7 @@ delimiter (Works like `.split(delimiter)`)
 - Fix a bug that was causing crashes in Python 3.5
 
 
+[#693]: https://github.com/huggingface/tokenizers/pull/693
 [#674]: https://github.com/huggingface/tokenizers/pull/674
 [#656]: https://github.com/huggingface/tokenizers/pull/656
 [#652]: https://github.com/huggingface/tokenizers/pull/652
@@ -30,4 +30,4 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
 
 $(DATA_DIR)/roberta.json :
 	$(dir_guard)
-	wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
+	wget https://huggingface.co/roberta-large/raw/main/tokenizer.json -O $@
@@ -5,3 +5,4 @@ ByteLevel = decoders.ByteLevel
 WordPiece = decoders.WordPiece
 Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
+CTC = decoders.CTC
@@ -68,6 +68,35 @@ class ByteLevel(Decoder):
     """
     pass
 
+class CTC(Decoder):
+    """
+    CTC Decoder
+
+    Args:
+        pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
+            The pad token used by CTC to delimit a new token.
+        word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
+            The word delimiter token. It will be replaced by a <space>
+        cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether to cleanup some tokenization artifacts.
+            Mainly spaces before punctuation, and some abbreviated english forms.
+    """
+
+    def __init__(self, pad_token="<pad>", word_delimiter_token="|", cleanup=True):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Metaspace(Decoder):
     """
     Metaspace Decoder
@@ -8,6 +8,7 @@ use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_level::ByteLevel;
+use tk::decoders::ctc::CTC;
 use tk::decoders::metaspace::Metaspace;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
@@ -43,6 +44,7 @@ impl PyDecoder {
             DecoderWrapper::WordPiece(_) => Py::new(py, (PyWordPieceDec {}, base))?.into_py(py),
             DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
             DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
+            DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
         },
     })
 }
@@ -264,6 +266,65 @@ impl PyBPEDecoder {
     }
 }
 
+/// CTC Decoder
+///
+/// Args:
+///     pad_token (:obj:`str`, `optional`, defaults to :obj:`<pad>`):
+///         The pad token used by CTC to delimit a new token.
+///     word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`|`):
+///         The word delimiter token. It will be replaced by a <space>
+///     cleanup (:obj:`bool`, `optional`, defaults to :obj:`True`):
+///         Whether to cleanup some tokenization artifacts.
+///         Mainly spaces before punctuation, and some abbreviated english forms.
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name=CTC)]
+#[text_signature = "(self, pad_token=\"<pad>\", word_delimiter_token=\"|\", cleanup=True)"]
+pub struct PyCTCDecoder {}
+#[pymethods]
+impl PyCTCDecoder {
+    #[getter]
+    fn get_pad_token(self_: PyRef<Self>) -> String {
+        getter!(self_, CTC, pad_token.clone())
+    }
+
+    #[setter]
+    fn set_pad_token(self_: PyRef<Self>, pad_token: String) {
+        setter!(self_, CTC, pad_token, pad_token);
+    }
+
+    #[getter]
+    fn get_word_delimiter_token(self_: PyRef<Self>) -> String {
+        getter!(self_, CTC, word_delimiter_token.clone())
+    }
+
+    #[setter]
+    fn set_word_delimiter_token(self_: PyRef<Self>, word_delimiter_token: String) {
+        setter!(self_, CTC, word_delimiter_token, word_delimiter_token);
+    }
+
+    #[getter]
+    fn get_cleanup(self_: PyRef<Self>) -> bool {
+        getter!(self_, CTC, cleanup)
+    }
+
+    #[setter]
+    fn set_cleanup(self_: PyRef<Self>, cleanup: bool) {
+        setter!(self_, CTC, cleanup, cleanup);
+    }
+
+    #[new]
+    #[args(
+        pad_token = "String::from(\"<pad>\")",
+        word_delimiter_token = "String::from(\"|\")",
+        cleanup = "true"
+    )]
+    fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> (Self, PyDecoder) {
+        (
+            PyCTCDecoder {},
+            CTC::new(pad_token, word_delimiter_token, cleanup).into(),
+        )
+    }
+}
+
 #[derive(Clone)]
 pub(crate) struct CustomDecoder {
     inner: PyObject,
@@ -87,6 +87,7 @@ fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<decoders::PyWordPieceDec>()?;
     m.add_class::<decoders::PyMetaspaceDec>()?;
     m.add_class::<decoders::PyBPEDecoder>()?;
+    m.add_class::<decoders::PyCTCDecoder>()?;
     Ok(())
 }
 
@@ -15,7 +15,7 @@ where
     I: Iterator,
 {
     pub fn new(iter: I, length: Option<usize>) -> Self {
-        Self { iter, length }
+        Self { length, iter }
     }
 }
 
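The field swap above changes nothing at runtime; it plausibly quiets Clippy 1.52's `inconsistent_struct_constructor` lint (the commit message mentions updating clippy to 1.52), which asks that struct-expression fields appear in declaration order. A small sketch with a hypothetical struct, not the crate's actual type:

// Hypothetical type for illustration: `length` is declared before `iter`,
// so Clippy 1.52 flags `Self { iter, length }` in the constructor.
struct Wrapped<I> {
    length: Option<usize>,
    iter: I,
}

impl<I: Iterator> Wrapped<I> {
    fn new(iter: I, length: Option<usize>) -> Self {
        // Listing fields in declaration order satisfies the lint.
        Self { length, iter }
    }
}

fn main() {
    let items = [1, 2, 3];
    let w = Wrapped::new(items.iter(), Some(items.len()));
    assert_eq!(w.length, Some(3));
    assert_eq!(w.iter.count(), 3);
}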
@@ -2,7 +2,7 @@ import pytest
 import pickle
 import json
 
-from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder
+from tokenizers.decoders import Decoder, ByteLevel, WordPiece, Metaspace, BPEDecoder, CTC
 
 
 class TestByteLevel:
@@ -108,3 +108,45 @@ class TestBPEDecoder:
         # Modify these
         decoder.suffix = "</w>"
         assert decoder.suffix == "</w>"
+
+
+class TestCTCDecoder:
+    def test_instantiate(self):
+        assert CTC() is not None
+        assert CTC(pad_token="[PAD]") is not None
+        assert isinstance(CTC(), Decoder)
+        assert isinstance(CTC(), CTC)
+        assert isinstance(pickle.loads(pickle.dumps(CTC())), CTC)
+
+    def test_decoding(self):
+        decoder = CTC()
+        assert (
+            decoder.decode(
+                ["<pad>", "<pad>", "h", "e", "e", "l", "l", "<pad>", "l", "o", "o", "o", "<pad>"]
+            )
+            == "hello"
+        )
+        decoder = CTC(pad_token="[PAD]")
+        assert (
+            decoder.decode(
+                ["[PAD]", "[PAD]", "h", "e", "e", "l", "l", "[PAD]", "l", "o", "o", "o", "[PAD]"]
+            )
+            == "hello"
+        )
+
+    def test_can_modify(self):
+        decoder = CTC(pad_token="[PAD]")
+
+        assert decoder.pad_token == "[PAD]"
+        assert decoder.word_delimiter_token == "|"
+        assert decoder.cleanup == True
+
+        # Modify these
+        decoder.pad_token = "{pad}"
+        assert decoder.pad_token == "{pad}"
+
+        decoder.word_delimiter_token = "_"
+        assert decoder.word_delimiter_token == "_"
+
+        decoder.cleanup = False
+        assert decoder.cleanup == False
@@ -10,7 +10,7 @@ logger = sphinx.util.logging.getLogger(__name__)
 
 class RustRef:
     def __call__(self, name, rawtext, text, lineno, inliner, options={}, content=[]):
-        doctype = name.split(":")[1]
+        doctype = name.split("_")[1]
         parts = text.split("::")
 
         if text.startswith("~"):
@@ -87,10 +87,10 @@ class RustRef:
 
 
 def setup(app):
-    app.add_role("rust:struct", RustRef())
-    app.add_role("rust:func", RustRef())
-    app.add_role("rust:meth", RustRef())
-    app.add_role("rust:trait", RustRef())
+    app.add_role("rust_struct", RustRef())
+    app.add_role("rust_func", RustRef())
+    app.add_role("rust_meth", RustRef())
+    app.add_role("rust_trait", RustRef())
 
     return {
         "version": "0.1",
@@ -58,47 +58,47 @@
     classmethod
     static method
     Tokenizer
-        :rust:struct:`~tokenizers::tokenizer::Tokenizer`
+        :rust_struct:`~tokenizers::tokenizer::Tokenizer`
     Tokenizer.train
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::train`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::train`
     Tokenizer.save
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::save`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::save`
     Tokenizer.from_file
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::from_file`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::from_file`
     Tokenizer.encode
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::encode`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::encode`
     Tokenizer.encode_batch
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::encode_batch`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::encode_batch`
     Tokenizer.decode
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::decode`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::decode`
     Tokenizer.decode_batch
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::decode_batch`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::decode_batch`
     Tokenizer.token_to_id
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::token_to_id`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::token_to_id`
     Tokenizer.enable_padding
-        :rust:meth:`~tokenizers::tokenizer::Tokenizer::enable_padding`
+        :rust_meth:`~tokenizers::tokenizer::Tokenizer::enable_padding`
     Encoding
-        :rust:struct:`~tokenizers::tokenizer::Encoding`
+        :rust_struct:`~tokenizers::tokenizer::Encoding`
     TemplateProcessing
-        :rust:struct:`~tokenizers::processors::template::TemplateProcessing`
+        :rust_struct:`~tokenizers::processors::template::TemplateProcessing`
     Normalizer
-        :rust:trait:`~tokenizers::tokenizer::Normalizer`
+        :rust_trait:`~tokenizers::tokenizer::Normalizer`
     normalizers.Sequence
-        :rust:struct:`~tokenizers::normalizers::utils::Sequence`
+        :rust_struct:`~tokenizers::normalizers::utils::Sequence`
     pre_tokenizers.Whitespace
-        :rust:struct:`~tokenizers::normalizers::whitespace::Whitespace`
+        :rust_struct:`~tokenizers::normalizers::whitespace::Whitespace`
     PreTokenizer
-        :rust:trait:`~tokenizers::tokenizer::PreTokenizer`
+        :rust_trait:`~tokenizers::tokenizer::PreTokenizer`
     models.BPE
-        :rust:struct:`~tokenizers::models::bpe::BPE`
+        :rust_struct:`~tokenizers::models::bpe::BPE`
     models.Unigram
-        :rust:struct:`~tokenizers::models::unigram::Unigram`
+        :rust_struct:`~tokenizers::models::unigram::Unigram`
     models.WordLevel
-        :rust:struct:`~tokenizers::models::wordlevel::WordLevel`
+        :rust_struct:`~tokenizers::models::wordlevel::WordLevel`
     models.WordPiece
-        :rust:struct:`~tokenizers::models::wordpiece::WordPiece`
+        :rust_struct:`~tokenizers::models::wordpiece::WordPiece`
     Decoder
-        :rust:trait:`~tokenizers::tokenizer::Decoder`
+        :rust_trait:`~tokenizers::tokenizer::Decoder`
 
 .. entities:: node
@@ -44,7 +44,7 @@ Training the tokenizer
 .. entities:: rust
 
     BpeTrainer
-        :rust:struct:`~tokenizers::models::bpe::BpeTrainer`
+        :rust_struct:`~tokenizers::models::bpe::BpeTrainer`
     vocab_size
         :obj:`vocab_size`
     min_frequency
@@ -55,7 +55,7 @@ $(DATA_DIR)/bert-% :
 
 $(DATA_DIR)/unigram% :
 	$(dir_guard)
-	wget https://storage.googleapis.com/tokenizers/unigram$* -O $@
+	wget https://huggingface.co/Narsil/small/raw/main/unigram$* -O $@
 
 $(DATA_DIR)/albert-base-v1-tokenizer.json :
 	$(dir_guard)
@@ -70,7 +70,7 @@ $(DATA_DIR)/small.txt : $(DATA_DIR)/big.txt
 
 $(DATA_DIR)/roberta.json :
 	$(dir_guard)
-	wget https://storage.googleapis.com/tokenizers/roberta.json -O $@
+	wget https://huggingface.co/Narsil/small/raw/main/roberta.json -O $@
 
 $(DATA_DIR)/tokenizer-wiki.json :
 	$(dir_guard)
tokenizers/src/decoders/ctc.rs (new file, 103 lines)
@@ -0,0 +1,103 @@
+use crate::decoders::wordpiece;
+use crate::tokenizer::{Decoder, Result};
+
+use itertools::Itertools;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+/// The CTC (Connectionist Temporal Classification) decoder takes care
+/// of sanitizing a list of input tokens.
+/// Due to alignment problems, the output of some models can come
+/// with duplicated tokens.
+#[serde(tag = "type")]
+#[non_exhaustive]
+pub struct CTC {
+    /// The pad token used by CTC to delimit a new token.
+    pub pad_token: String,
+    /// The word delimiter token. It will be replaced by a <space>
+    pub word_delimiter_token: String,
+    /// Whether to cleanup some tokenization artifacts.
+    /// Mainly spaces before punctuation, and some abbreviated english forms.
+    pub cleanup: bool,
+}
+
+impl CTC {
+    pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self {
+        Self {
+            pad_token,
+            word_delimiter_token,
+            cleanup,
+        }
+    }
+}
+
+impl Default for CTC {
+    fn default() -> Self {
+        Self {
+            pad_token: "<pad>".to_string(),
+            word_delimiter_token: "|".to_string(),
+            cleanup: true,
+        }
+    }
+}
+
+impl Decoder for CTC {
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        let mut output = tokens
+            .into_iter()
+            .dedup()
+            .join("")
+            .replace(&self.pad_token, "");
+        if self.cleanup {
+            output = wordpiece::cleanup(output).replace(&self.word_delimiter_token, " ");
+        }
+        Ok(output)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn handmade_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad>"
+            .split(' ')
+            .map(|s| s.to_string())
+            .collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "hello".to_string()
+        );
+    }
+    #[test]
+    fn handmade_with_delimiter_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> h e e l l <pad> l o o o <pad> <pad> | <pad> w o o o r <pad> <pad> l l d <pad> <pad> <pad> <pad>"
+            .split(' ')
+            .map(|s| s.to_string())
+            .collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "hello world".to_string()
+        );
+    }
+    #[test]
+    fn librispeech_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "A MAN SAID TO THE UNIVERSE SIR I EXIST ".to_string()
+        );
+    }
+    #[test]
+    fn another_librispeech_sample() {
+        let ctc_decoder = CTC::default();
+        let id_to_string_result = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>".split(' ').map(|s| s.to_string()).collect();
+        assert_eq!(
+            ctc_decoder.decode(id_to_string_result).unwrap(),
+            "HIS INSTANCT PANIC WAS FOLLOWED BY A SMALL SHARP BLOW HIGH ON HIS CHEST ".to_string()
+        );
+    }
+}
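For reference, the decode path in this new file is a greedy CTC collapse: merge consecutive duplicate frames, drop the pad token, and (during cleanup) turn the word delimiter into a space. A dependency-free sketch of the same three steps; `collapse_ctc` is an illustrative name, not part of the crate:

// Minimal re-implementation of the greedy CTC collapse for illustration;
// the crate itself uses itertools::dedup over the token stream.
fn collapse_ctc(tokens: &[&str], pad: &str, delim: &str) -> String {
    let mut out = String::new();
    let mut prev: Option<&str> = None;
    for &tok in tokens {
        // 1. Merge consecutive duplicates ("h h e" -> "h e")
        if prev == Some(tok) {
            continue;
        }
        prev = Some(tok);
        // 2. Drop the pad token entirely
        if tok == pad {
            continue;
        }
        // 3. Replace the word delimiter by a space
        out.push_str(if tok == delim { " " } else { tok });
    }
    out
}

fn main() {
    let frames = ["<pad>", "h", "h", "e", "e", "l", "l", "<pad>", "l", "o", "|", "w"];
    assert_eq!(collapse_ctc(&frames, "<pad>", "|"), "hello w");
}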
@@ -1,4 +1,5 @@
 pub mod bpe;
+pub mod ctc;
 pub mod wordpiece;
 
 // Re-export these as decoders
@@ -8,6 +9,7 @@ pub use super::pre_tokenizers::metaspace;
 use serde::{Deserialize, Serialize};
 
 use crate::decoders::bpe::BPEDecoder;
+use crate::decoders::ctc::CTC;
 use crate::decoders::wordpiece::WordPiece;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
@@ -20,6 +22,7 @@ pub enum DecoderWrapper {
     ByteLevel(ByteLevel),
     WordPiece(WordPiece),
     Metaspace(Metaspace),
+    CTC(CTC),
 }
 
 impl Decoder for DecoderWrapper {
@@ -29,6 +32,7 @@ impl Decoder for DecoderWrapper {
             DecoderWrapper::ByteLevel(bl) => bl.decode(tokens),
             DecoderWrapper::Metaspace(ms) => ms.decode(tokens),
             DecoderWrapper::WordPiece(wp) => wp.decode(tokens),
+            DecoderWrapper::CTC(ctc) => ctc.decode(tokens),
         }
     }
 }
@@ -37,3 +41,4 @@ impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
 impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
 impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
+impl_enum_from!(CTC, DecoderWrapper, CTC);
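Since the new `CTC` struct is declared with `#[serde(tag = "type")]`, its serialized form carries a `"type": "CTC"` field, which is what lets a saved tokenizer file be deserialized back into the matching decoder. A standalone sketch of that tagging behavior, using a stand-in struct that mirrors the same fields and assuming `serde` (with derive) and `serde_json` as dependencies:

use serde::{Deserialize, Serialize};

// Stand-in struct mirroring the crate's CTC fields, for illustration only.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
struct CTC {
    pad_token: String,
    word_delimiter_token: String,
    cleanup: bool,
}

fn main() {
    let ctc = CTC {
        pad_token: "<pad>".into(),
        word_delimiter_token: "|".into(),
        cleanup: true,
    };
    let json = serde_json::to_string(&ctc).unwrap();
    // The internal tag names the type in the serialized form:
    assert_eq!(
        json,
        r#"{"type":"CTC","pad_token":"<pad>","word_delimiter_token":"|","cleanup":true}"#
    );
    let back: CTC = serde_json::from_str(&json).unwrap();
    assert!(back.cleanup);
}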
@@ -28,23 +28,26 @@ impl Default for WordPiece {
         }
     }
 }
+
+pub fn cleanup(dirty_input: String) -> String {
+    dirty_input
+        .replace(" .", ".")
+        .replace(" ?", "?")
+        .replace(" !", "!")
+        .replace(" ,", ",")
+        .replace(" ' ", "'")
+        .replace(" n't", "n't")
+        .replace(" 'm", "'m")
+        .replace(" do not", " don't")
+        .replace(" 's", "'s")
+        .replace(" 've", "'ve")
+        .replace(" 're", "'re")
+}
 
 impl Decoder for WordPiece {
     fn decode(&self, tokens: Vec<String>) -> Result<String> {
         let mut output = tokens.join(" ").replace(&format!(" {}", self.prefix), "");
         if self.cleanup {
-            output = output
-                .replace(" .", ".")
-                .replace(" ?", "?")
-                .replace(" !", "!")
-                .replace(" ,", ",")
-                .replace(" ' ", "'")
-                .replace(" n't", "n't")
-                .replace(" 'm", "'m")
-                .replace(" do not", " don't")
-                .replace(" 's", "'s")
-                .replace(" 've", "'ve")
-                .replace(" 're", "'re");
+            output = cleanup(output);
         }
 
         Ok(output)
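Pulling `cleanup` out into a free function is what allows the new CTC decoder to reuse WordPiece's detokenization fixes. To illustrate its effect, here is a standalone copy of the function with a small example:

// Standalone copy of the shared cleanup pass, shown here only to
// illustrate its effect on detokenized text.
fn cleanup(dirty_input: String) -> String {
    dirty_input
        .replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ,", ",")
        .replace(" ' ", "'")
        .replace(" n't", "n't")
        .replace(" 'm", "'m")
        .replace(" do not", " don't")
        .replace(" 's", "'s")
        .replace(" 've", "'ve")
        .replace(" 're", "'re")
}

fn main() {
    // Spaces before punctuation and split contractions get repaired:
    let raw = String::from("it ' s fine , is n't it ?");
    assert_eq!(cleanup(raw), "it's fine, isn't it?");
}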
@@ -167,7 +167,7 @@ impl PostProcessor for ByteLevel {
             }
         }
 
-        PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)
+        <dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)
     }
 }
 
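A note on the `<dyn PostProcessor>::` spelling introduced here and in the following hunks: `default_process` appears to be an inherent function on the trait-object type (an `impl dyn PostProcessor` block, inferred from these call sites), and with the toolchain bump to 1.52 named in the commit message, the old bare-trait spelling `PostProcessor::default_process(...)` trips the `bare_trait_objects` deprecation. A minimal, self-contained sketch of the pattern, using a hypothetical `Process` trait rather than the crate's API:

#[allow(dead_code)]
trait Process {
    fn run(&self, input: &str) -> String;
}

// Functions can be attached to the trait-object type itself:
impl dyn Process {
    fn default_process(input: &str) -> String {
        input.to_string()
    }
}

fn main() {
    // `Process::default_process(...)` relies on the bare trait name
    // denoting `dyn Process` and is deprecated; the explicit form is:
    let out = <dyn Process>::default_process("hello");
    assert_eq!(out, "hello");
}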
@@ -41,7 +41,11 @@ impl PostProcessor for BertProcessing {
         add_special_tokens: bool,
     ) -> Result<Encoding> {
         if !add_special_tokens {
-            return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
+            return <dyn PostProcessor>::default_process(
+                encoding,
+                pair_encoding,
+                add_special_tokens,
+            );
         }
 
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
@@ -74,7 +74,11 @@ impl PostProcessor for RobertaProcessing {
         }
 
         if !add_special_tokens {
-            return PostProcessor::default_process(encoding, pair_encoding, add_special_tokens);
+            return <dyn PostProcessor>::default_process(
+                encoding,
+                pair_encoding,
+                add_special_tokens,
+            );
         }
 
         let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
|
@ -877,7 +877,7 @@ where
|
|||||||
let final_encoding = if let Some(processor) = &self.post_processor {
|
let final_encoding = if let Some(processor) = &self.post_processor {
|
||||||
processor.process(encoding, pair_encoding, add_special_tokens)?
|
processor.process(encoding, pair_encoding, add_special_tokens)?
|
||||||
} else {
|
} else {
|
||||||
PostProcessor::default_process(encoding, pair_encoding, add_special_tokens)?
|
<dyn PostProcessor>::default_process(encoding, pair_encoding, add_special_tokens)?
|
||||||
};
|
};
|
||||||
|
|
||||||
// 3. Then we pad if needed
|
// 3. Then we pad if needed
|
||||||
|