Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Adding 2 new decoders (#1196)

* Adding 2 new decoders:
  - Fuse will simply concatenate all tokens into one string
  - Strip will remove n chars from the left or right

  Sequence(Replace("_", " "), Fuse(), Strip(1, 0)) should be what we want
  for the `Metaspace` thing.

  Note: this adds a new dependency (`monostate`) for better parsing of
  decoders. It is needed because of untagged enums, which can match
  anything; the `MustBe` type ensures there is no ambiguity between Fuse
  and ByteFallback. Since both are new, the chances of backward
  incompatibility are low.

* Fixing pickling/unpickling (using default args).
* Stub.
* Black.
* Fixing node.
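As a sanity check of the composition the message describes, here is a minimal sketch of that `Metaspace`-replacement pipeline in Rust. Module paths are the ones that appear in the diff below (via the `tk`/`crate` aliases); the exact `Sequence::new` and `Replace::new` signatures are assumptions, and "▁" stands in for the marker written as "_" above.

use tokenizers::decoders::fuse::Fuse;
use tokenizers::decoders::sequence::Sequence;
use tokenizers::decoders::strip::Strip;
use tokenizers::decoders::DecoderWrapper;
use tokenizers::normalizers::replace::Replace;
use tokenizers::tokenizer::{Decoder, Result};

fn main() -> Result<()> {
    // Replace the metaspace marker with a plain space, fuse all tokens
    // into one string, then strip the single leading space.
    let decoder = Sequence::new(vec![
        DecoderWrapper::Replace(Replace::new("▁", " ".to_string())?),
        Fuse::new().into(),
        Strip::new(1, 0).into(),
    ]);
    let out = decoder.decode(vec!["▁Hey".into(), "▁friend!".into()])?;
    assert_eq!(out, "Hey friend!");
    Ok(())
}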
bindings/node/lib/bindings/decoders.d.ts (vendored, +12)

@@ -35,6 +35,18 @@ export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
  */
 export function byteFallbackDecoder(): Decoder;
 
+/**
+ * Instantiate a new Fuse Decoder which fuses all tokens into one string
+ */
+export function fuseDecoder(): Decoder;
+
+/**
+ * Instantiate a new Strip Decoder
+ * @param [left] The number of chars to remove from the left of each token
+ * @param [right] The number of chars to remove from the right of each token
+ */
+export function stripDecoder(left: number, right: number): Decoder;
+
 /**
  * Instantiate a new Metaspace
  *
@@ -5,6 +5,8 @@ module.exports = {
   replaceDecoder: native.decoders_Replace,
   wordPieceDecoder: native.decoders_WordPiece,
   byteFallbackDecoder: native.decoders_ByteFallback,
+  fuseDecoder: native.decoders_Fuse,
+  stripDecoder: native.decoders_Strip,
   metaspaceDecoder: native.decoders_Metaspace,
   bpeDecoder: native.decoders_BPEDecoder,
   ctcDecoder: native.decoders_CTC,
@@ -2,9 +2,11 @@ import {
   bpeDecoder,
   byteFallbackDecoder,
   ctcDecoder,
+  fuseDecoder,
   metaspaceDecoder,
   replaceDecoder,
   sequenceDecoder,
+  stripDecoder,
   wordPieceDecoder,
 } from "./decoders";
 
@@ -51,6 +53,26 @@ describe("replaceDecoder", () => {
   });
 });
 
+describe("fuseDecoder", () => {
+  it("accepts `undefined` as first parameter", () => {
+    expect(fuseDecoder()).toBeDefined();
+  });
+
+  it("can decode arrays of strings", () => {
+    expect(fuseDecoder().decode(["Hel", "lo"])).toEqual("Hello");
+  });
+});
+
+describe("stripDecoder", () => {
+  it("accepts `undefined` as first parameter", () => {
+    expect(stripDecoder(0, 0)).toBeDefined();
+  });
+
+  it("can decode arrays of strings", () => {
+    expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
+  });
+});
+
 describe("metaspaceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
     expect(metaspaceDecoder(undefined)).toBeDefined();
@@ -96,6 +96,26 @@ fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// fuse()
+fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(tk::decoders::fuse::Fuse::new().into()));
+    Ok(decoder)
+}
+
+/// strip(left: number, right: number)
+fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let left: usize = cx.extract(0)?;
+    let right: usize = cx.extract(1)?;
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::decoders::strip::Strip::new(left, right).into(),
+    ));
+    Ok(decoder)
+}
+
 /// metaspace(replacement: String = "_", add_prefix_space: bool = true)
 fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
 
@@ -173,6 +193,8 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Replace", prefix), replace)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
+    m.export_function(&format!("{}_Fuse", prefix), fuse)?;
+    m.export_function(&format!("{}_Strip", prefix), strip)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
     m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
     m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
@@ -6,6 +6,8 @@ ByteLevel = decoders.ByteLevel
 Replace = decoders.Replace
 WordPiece = decoders.WordPiece
 ByteFallback = decoders.ByteFallback
+Fuse = decoders.Fuse
+Strip = decoders.Strip
 Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
@@ -121,6 +121,29 @@ class CTC(Decoder):
     """
     pass
 
+class Fuse(Decoder):
+    """
+    Fuse Decoder
+    Fuse simply fuses every token into a single string.
+    This is usually the last step of decoding; this decoder exists only
+    in case other decoders need to run *after* the fusion.
+    """
+
+    def __init__(self):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Metaspace(Decoder):
     """
     Metaspace Decoder
 
@@ -197,6 +220,27 @@ class Sequence(Decoder):
     """
     pass
 
+class Strip(Decoder):
+    """
+    Strip Decoder
+    Strips n left characters of each token, or n right characters of each token
+    """
+
+    def __init__(self, left=0, right=0):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class WordPiece(Decoder):
     """
     WordPiece Decoder
@@ -11,8 +11,10 @@ use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_fallback::ByteFallback;
 use tk::decoders::byte_level::ByteLevel;
 use tk::decoders::ctc::CTC;
+use tk::decoders::fuse::Fuse;
 use tk::decoders::metaspace::Metaspace;
 use tk::decoders::sequence::Sequence;
+use tk::decoders::strip::Strip;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
 use tk::normalizers::replace::Replace;
 
@@ -47,6 +49,8 @@ impl PyDecoder {
             DecoderWrapper::ByteFallback(_) => {
                 Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
             }
+            DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?.into_py(py),
+            DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?.into_py(py),
             DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
             DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
             DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
@@ -238,6 +242,56 @@ impl PyByteFallbackDec {
     }
 }
 
+/// Fuse Decoder
+/// Fuse simply fuses every token into a single string.
+/// This is usually the last step of decoding; this decoder exists only
+/// in case other decoders need to run *after* the fusion.
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Fuse")]
+#[pyo3(text_signature = "(self)")]
+pub struct PyFuseDec {}
+#[pymethods]
+impl PyFuseDec {
+    #[new]
+    #[pyo3(signature = ())]
+    fn new() -> (Self, PyDecoder) {
+        (PyFuseDec {}, Fuse::new().into())
+    }
+}
+
+/// Strip Decoder
+/// Strips n left characters of each token, or n right characters of each token
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
+#[pyo3(text_signature = "(self, left=0, right=0)")]
+pub struct PyStrip {}
+#[pymethods]
+impl PyStrip {
+    #[getter]
+    fn get_left(self_: PyRef<Self>) -> usize {
+        getter!(self_, Strip, left)
+    }
+
+    #[setter]
+    fn set_left(self_: PyRef<Self>, left: usize) {
+        setter!(self_, Strip, left, left)
+    }
+
+    #[getter]
+    fn get_right(self_: PyRef<Self>) -> usize {
+        getter!(self_, Strip, right)
+    }
+
+    #[setter]
+    fn set_right(self_: PyRef<Self>, right: usize) {
+        setter!(self_, Strip, right, right)
+    }
+
+    #[new]
+    #[pyo3(signature = (left=0, right=0))]
+    fn new(left: usize, right: usize) -> (Self, PyDecoder) {
+        (PyStrip {}, Strip::new(left, right).into())
+    }
+}
+
 /// Metaspace Decoder
 ///
 /// Args:
 
@@ -497,6 +551,8 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyReplaceDec>()?;
     m.add_class::<PyWordPieceDec>()?;
     m.add_class::<PyByteFallbackDec>()?;
+    m.add_class::<PyFuseDec>()?;
+    m.add_class::<PyStrip>()?;
     m.add_class::<PyMetaspaceDec>()?;
     m.add_class::<PyBPEDecoder>()?;
     m.add_class::<PyCTCDecoder>()?;
@@ -13,6 +13,8 @@ from tokenizers.decoders import (
     WordPiece,
     ByteFallback,
     Replace,
+    Strip,
+    Fuse,
 )
 
 
@@ -94,6 +96,30 @@ class TestByteFallback:
         assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
 
 
+class TestFuse:
+    def test_instantiate(self):
+        assert Fuse() is not None
+        assert isinstance(Fuse(), Decoder)
+        assert isinstance(Fuse(), Fuse)
+        assert isinstance(pickle.loads(pickle.dumps(Fuse())), Fuse)
+
+    def test_decoding(self):
+        decoder = Fuse()
+        assert decoder.decode(["My", " na", "me"]) == "My name"
+
+
+class TestStrip:
+    def test_instantiate(self):
+        assert Strip(left=0, right=0) is not None
+        assert isinstance(Strip(left=0, right=0), Decoder)
+        assert isinstance(Strip(left=0, right=0), Strip)
+        assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
+
+    def test_decoding(self):
+        decoder = Strip(left=1, right=0)
+        assert decoder.decode(["My", " na", "me"]) == "ynae"
+
+
 class TestMetaspace:
     def test_instantiate(self):
         assert Metaspace() is not None
@@ -67,6 +67,7 @@ thiserror = "1.0.30"
 fancy-regex = { version = "0.10", optional = true}
 getrandom = { version = "0.2.6" }
 esaxx-rs = { version = "0.1", default-features = false, features=[]}
+monostate = "0.1.5"
 
 [features]
 default = ["progressbar", "http", "cli", "onig", "esaxx_fast"]
@@ -1,4 +1,5 @@
 use crate::tokenizer::{Decoder, Result};
+use monostate::MustBe;
 
 use serde::{Deserialize, Serialize};
 
@@ -6,13 +7,17 @@ use serde::{Deserialize, Serialize};
 /// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
 /// to pure bytes, and attempts to make them into a string. If the tokens
 /// cannot be decoded you will get � instead for each inconvertible byte token
 #[serde(tag = "type")]
 #[non_exhaustive]
-pub struct ByteFallback {}
+pub struct ByteFallback {
+    #[serde(rename = "type")]
+    type_: MustBe!("ByteFallback"),
+}
 
 impl ByteFallback {
     pub fn new() -> Self {
-        Self {}
+        Self {
+            type_: MustBe!("ByteFallback"),
+        }
     }
 }
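For readers unfamiliar with `monostate`: `MustBe!("ByteFallback")` is a zero-sized type that serializes as exactly that string and refuses to deserialize anything else. A minimal standalone sketch of that behavior (the struct is a toy, not the one from the diff, and `serde_json` is assumed as a dev dependency):

use monostate::MustBe;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
struct ByteFallbackLike {
    #[serde(rename = "type")]
    type_: MustBe!("ByteFallback"),
}

fn main() {
    // The tag round-trips as an ordinary string field...
    let ok: ByteFallbackLike =
        serde_json::from_str(r#"{"type":"ByteFallback"}"#).unwrap();
    assert_eq!(
        serde_json::to_string(&ok).unwrap(),
        r#"{"type":"ByteFallback"}"#
    );
    // ...while any other value fails to parse, which is what lets the
    // untagged DecoderWrapper enum tell parameter-free variants apart.
    assert!(serde_json::from_str::<ByteFallbackLike>(r#"{"type":"Fuse"}"#).is_err());
}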
tokenizers/src/decoders/fuse.rs (new file, +43)

@@ -0,0 +1,43 @@
+use crate::tokenizer::{Decoder, Result};
+use monostate::MustBe;
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default)]
+/// Fuse simply fuses all tokens into one big string.
+/// It's usually the last decoding step anyway, but this
+/// decoder exists in case some decoders need to happen after
+/// that step
+#[non_exhaustive]
+pub struct Fuse {
+    #[serde(rename = "type")]
+    type_: MustBe!("Fuse"),
+}
+
+impl Fuse {
+    pub fn new() -> Self {
+        Self {
+            type_: MustBe!("Fuse"),
+        }
+    }
+}
+
+impl Decoder for Fuse {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        let new_string = tokens.join("");
+        Ok(vec![new_string])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn decode() {
+        let decoder = Fuse::new();
+        let res = decoder
+            .decode_chain(vec!["Hey".into(), " friend!".into()])
+            .unwrap();
+        assert_eq!(res, vec!["Hey friend!"]);
+    }
+}
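Worth spelling out: assuming the `Decoder` trait's default `decode` simply joins whatever `decode_chain` returns, `Fuse` on its own changes nothing at the top level; it matters inside a `Sequence`, where it guarantees that every decoder placed after it sees a single element. A quick sketch:

use tokenizers::decoders::fuse::Fuse;
use tokenizers::tokenizer::Decoder;

fn main() {
    let fuse = Fuse::new();
    // decode_chain collapses the stream to one element, so any decoder
    // placed after Fuse in a Sequence sees exactly one "token".
    let chained = fuse
        .decode_chain(vec!["Hey".into(), " friend!".into()])
        .unwrap();
    assert_eq!(chained, vec!["Hey friend!"]);
}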
@@ -1,7 +1,9 @@
 pub mod bpe;
 pub mod byte_fallback;
 pub mod ctc;
+pub mod fuse;
 pub mod sequence;
+pub mod strip;
 pub mod wordpiece;
 
 // Re-export these as decoders
 
@@ -13,7 +15,9 @@ use serde::{Deserialize, Serialize};
 use crate::decoders::bpe::BPEDecoder;
 use crate::decoders::byte_fallback::ByteFallback;
 use crate::decoders::ctc::CTC;
+use crate::decoders::fuse::Fuse;
 use crate::decoders::sequence::Sequence;
+use crate::decoders::strip::Strip;
 use crate::decoders::wordpiece::WordPiece;
 use crate::normalizers::replace::Replace;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 
@@ -30,10 +34,8 @@ pub enum DecoderWrapper {
     CTC(CTC),
     Sequence(Sequence),
     Replace(Replace),
-    // XXX: This is an untagged enum, which unfortunately means order
-    // is **CRITICAL**. We absolutely need to make sure order is correct.
-    // Since byte fallback is parameter free, it **has** to be last, and will
-    // unfortunately match pretty much everything.
+    Fuse(Fuse),
+    Strip(Strip),
     ByteFallback(ByteFallback),
 }
 
@@ -48,6 +50,8 @@ impl Decoder for DecoderWrapper {
             Self::Sequence(seq) => seq.decode_chain(tokens),
             Self::Replace(seq) => seq.decode_chain(tokens),
             Self::ByteFallback(bf) => bf.decode_chain(tokens),
+            Self::Strip(bf) => bf.decode_chain(tokens),
+            Self::Fuse(bf) => bf.decode_chain(tokens),
         }
     }
 }
 
@@ -55,6 +59,8 @@ impl Decoder for DecoderWrapper {
 impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
 impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
 impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
+impl_enum_from!(Fuse, DecoderWrapper, Fuse);
+impl_enum_from!(Strip, DecoderWrapper, Strip);
 impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
 
@@ -72,4 +78,18 @@ mod tests {
         let serialized = serde_json::to_string(&decoder).unwrap();
         assert_eq!(serialized, json);
     }
 
+    #[test]
+    fn decoder_serialization_other_no_arg() {
+        let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
+        let serialized = serde_json::to_string(&decoder).unwrap();
+        assert_eq!(serialized, json);
+    }
+
+    #[test]
+    fn decoder_serialization_no_decode() {
+        let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
+    }
 }
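The removed XXX comment points at the hazard the `MustBe` tags eliminate: with `#[serde(untagged)]`, serde tries variants in declaration order, and a field-free struct matches any JSON object, so it previously had to sit last. A toy illustration of that failure mode (these types are illustrative, not the real wrappers):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct WithArgs {
    left: usize,
    right: usize,
}

// A field-free struct: without a MustBe! tag it matches any JSON object.
#[derive(Deserialize, Debug)]
struct Anything {}

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Wrapper {
    // Order matters: if Anything came first, it would also swallow
    // payloads meant for WithArgs.
    WithArgs(WithArgs),
    Anything(Anything),
}

fn main() {
    let w: Wrapper = serde_json::from_str(r#"{"left":1,"right":0}"#).unwrap();
    assert!(matches!(w, Wrapper::WithArgs(_)));
    // An empty object still falls through to the catch-all variant.
    let w: Wrapper = serde_json::from_str(r#"{}"#).unwrap();
    assert!(matches!(w, Wrapper::Anything(_)));
}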
tokenizers/src/decoders/strip.rs (new file, +49)

@@ -0,0 +1,49 @@
+use crate::tokenizer::{Decoder, Result};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Deserialize, Clone, Debug, Serialize, Default)]
+/// Strip is a simple decoder, which removes a fixed number of
+/// characters from each token: `left` characters from the start
+/// and `right` characters from the end
+#[serde(tag = "type")]
+#[non_exhaustive]
+pub struct Strip {
+    pub left: usize,
+    pub right: usize,
+}
+
+impl Strip {
+    pub fn new(left: usize, right: usize) -> Self {
+        Self { left, right }
+    }
+}
+
+impl Decoder for Strip {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        Ok(tokens
+            .into_iter()
+            .map(|token| token[self.left..token.len() - self.right].to_string())
+            .collect())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn decode() {
+        let decoder = Strip::new(1, 0);
+        let res = decoder
+            .decode_chain(vec!["Hey".into(), " friend!".into()])
+            .unwrap();
+        assert_eq!(res, vec!["ey", "friend!"]);
+
+        let decoder = Strip::new(0, 1);
+        let res = decoder
+            .decode_chain(vec!["Hey".into(), " friend!".into()])
+            .unwrap();
+        assert_eq!(res, vec!["He", " friend"]);
+    }
+}
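One note on the implementation above: the slice `token[self.left..token.len() - self.right]` indexes bytes, so `left` and `right` effectively count bytes, and a token shorter than `left + right` (or a cut landing inside a multi-byte character) will panic. A small usage sketch with single-byte delimiters:

use tokenizers::decoders::strip::Strip;
use tokenizers::tokenizer::Decoder;

fn main() {
    // Remove one delimiter byte from each side of every token.
    let strip = Strip::new(1, 1);
    let out = strip
        .decode_chain(vec!["[tok]".into(), "[en]".into()])
        .unwrap();
    assert_eq!(out, vec!["tok", "en"]);
}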