diff --git a/bindings/node/lib/bindings/decoders.d.ts b/bindings/node/lib/bindings/decoders.d.ts
index d6e809b8..be9dc3f7 100644
--- a/bindings/node/lib/bindings/decoders.d.ts
+++ b/bindings/node/lib/bindings/decoders.d.ts
@@ -35,6 +35,18 @@ export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
  */
 export function byteFallbackDecoder(): Decoder;
 
+/**
+ * Instantiate a new Fuse Decoder, which fuses all tokens into one string
+ */
+export function fuseDecoder(): Decoder;
+
+/**
+ * Instantiate a new Strip Decoder
+ * @param left The number of chars to remove from the left of each token
+ * @param right The number of chars to remove from the right of each token
+ */
+export function stripDecoder(left: number, right: number): Decoder;
+
 /**
  * Instantiate a new Metaspace
  *
diff --git a/bindings/node/lib/bindings/decoders.js b/bindings/node/lib/bindings/decoders.js
index 9d10b07f..14f452cd 100644
--- a/bindings/node/lib/bindings/decoders.js
+++ b/bindings/node/lib/bindings/decoders.js
@@ -5,6 +5,8 @@ module.exports = {
   replaceDecoder: native.decoders_Replace,
   wordPieceDecoder: native.decoders_WordPiece,
   byteFallbackDecoder: native.decoders_ByteFallback,
+  fuseDecoder: native.decoders_Fuse,
+  stripDecoder: native.decoders_Strip,
   metaspaceDecoder: native.decoders_Metaspace,
   bpeDecoder: native.decoders_BPEDecoder,
   ctcDecoder: native.decoders_CTC,
diff --git a/bindings/node/lib/bindings/decoders.test.ts b/bindings/node/lib/bindings/decoders.test.ts
index 7e31b576..624803b1 100644
--- a/bindings/node/lib/bindings/decoders.test.ts
+++ b/bindings/node/lib/bindings/decoders.test.ts
@@ -2,9 +2,11 @@ import {
   bpeDecoder,
   byteFallbackDecoder,
   ctcDecoder,
+  fuseDecoder,
   metaspaceDecoder,
   replaceDecoder,
   sequenceDecoder,
+  stripDecoder,
   wordPieceDecoder,
 } from "./decoders";
 
@@ -51,6 +53,26 @@ describe("replaceDecoder", () => {
   });
 });
 
+describe("fuseDecoder", () => {
+  it("can be instantiated", () => {
+    expect(fuseDecoder()).toBeDefined();
+  });
+
+  it("can decode arrays of strings", () => {
+    expect(fuseDecoder().decode(["Hel", "lo"])).toEqual("Hello");
+  });
+});
+
+describe("stripDecoder", () => {
+  it("can be instantiated", () => {
+    expect(stripDecoder(0, 0)).toBeDefined();
+  });
+
+  it("can decode arrays of strings", () => {
+    expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
+  });
+});
+
 describe("metaspaceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
     expect(metaspaceDecoder(undefined)).toBeDefined();
diff --git a/bindings/node/native/src/decoders.rs b/bindings/node/native/src/decoders.rs
index c8433b8d..fa3c712e 100644
--- a/bindings/node/native/src/decoders.rs
+++ b/bindings/node/native/src/decoders.rs
@@ -96,6 +96,26 @@ fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// fuse()
+fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(tk::decoders::fuse::Fuse::new().into()));
+    Ok(decoder)
+}
+
+/// strip(left: number, right: number)
+fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let left: usize = cx.extract(0)?;
+    let right: usize = cx.extract(1)?;
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::decoders::strip::Strip::new(left, right).into(),
+    ));
+    Ok(decoder)
+}
+
 /// metaspace(replacement: String = "_", add_prefix_space: bool = true)
 fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
@@ -173,6 +193,8 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_Replace", prefix), replace)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
+    m.export_function(&format!("{}_Fuse", prefix), fuse)?;
+    m.export_function(&format!("{}_Strip", prefix), strip)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
     m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
     m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;
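Note on the `decode` / `decode_chain` pair the tests above exercise: `decode` is expected to join whatever pieces `decode_chain` leaves behind. The sketch below is a self-contained stand-in for the core crate's `Decoder` trait written from that assumption, not a quote of the actual source; the local `Result` alias and the `Upper` toy decoder are likewise illustrative.

    // Hedged sketch of the Decoder contract assumed by the bindings above.
    // `Result` here is a local stand-in, not the tokenizers alias itself.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    trait Decoder {
        // Each decoder rewrites the list of token pieces...
        fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>>;

        // ...and `decode` concatenates what is left, which is why
        // `stripDecoder(1, 0).decode(["Hel", "lo"])` yields "el" + "o" == "elo".
        fn decode(&self, tokens: Vec<String>) -> Result<String> {
            Ok(self.decode_chain(tokens)?.join(""))
        }
    }

    // Toy decoder used only to exercise the default `decode`.
    struct Upper;
    impl Decoder for Upper {
        fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
            Ok(tokens.into_iter().map(|t| t.to_uppercase()).collect())
        }
    }

    fn main() {
        assert_eq!(Upper.decode(vec!["Hel".into(), "lo".into()]).unwrap(), "HELLO");
    }
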
diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.py b/bindings/python/py_src/tokenizers/decoders/__init__.py
index 44c906e4..a717379c 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.py
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.py
@@ -6,6 +6,8 @@ ByteLevel = decoders.ByteLevel
 Replace = decoders.Replace
 WordPiece = decoders.WordPiece
 ByteFallback = decoders.ByteFallback
+Fuse = decoders.Fuse
+Strip = decoders.Strip
 Metaspace = decoders.Metaspace
 BPEDecoder = decoders.BPEDecoder
 CTC = decoders.CTC
diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
index ea174e92..21fe746a 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -121,6 +121,29 @@ class CTC(Decoder):
     """
     pass
 
+class Fuse(Decoder):
+    """
+    Fuse Decoder
+    Fuse simply fuses every token into a single string.
+    This is usually the last step of decoding; this decoder exists in case
+    other decoders need to be applied *after* the fusion
+    """
+
+    def __init__(self):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Metaspace(Decoder):
     """
     Metaspace Decoder
@@ -197,6 +220,27 @@ class Sequence(Decoder):
     """
     pass
 
+class Strip(Decoder):
+    """
+    Strip Decoder
+    Strips `left` characters from the start of each token and `right` from its end
+    """
+
+    def __init__(self, left=0, right=0):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class WordPiece(Decoder):
     """
     WordPiece Decoder
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index 85d76ec0..e1b5bd79 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -11,8 +11,10 @@ use tk::decoders::bpe::BPEDecoder;
 use tk::decoders::byte_fallback::ByteFallback;
 use tk::decoders::byte_level::ByteLevel;
 use tk::decoders::ctc::CTC;
+use tk::decoders::fuse::Fuse;
 use tk::decoders::metaspace::Metaspace;
 use tk::decoders::sequence::Sequence;
+use tk::decoders::strip::Strip;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
 use tk::normalizers::replace::Replace;
@@ -47,6 +49,8 @@ impl PyDecoder {
             DecoderWrapper::ByteFallback(_) => {
                 Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
             }
+            DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?.into_py(py),
+            DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?.into_py(py),
             DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
             DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
             DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
@@ -238,6 +242,56 @@ impl PyByteFallbackDec {
     }
 }
 
+/// Fuse Decoder
+/// Fuse simply fuses every token into a single string.
+/// This is usually the last step of decoding; this decoder exists in case
+/// other decoders need to be applied *after* the fusion
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Fuse")]
+#[pyo3(text_signature = "(self)")]
+pub struct PyFuseDec {}
+#[pymethods]
+impl PyFuseDec {
+    #[new]
+    #[pyo3(signature = ())]
+    fn new() -> (Self, PyDecoder) {
+        (PyFuseDec {}, Fuse::new().into())
+    }
+}
+
+/// Strip Decoder
+/// Strips `left` characters from the start of each token and `right` from its end
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
+#[pyo3(text_signature = "(self, left=0, right=0)")]
+pub struct PyStrip {}
+#[pymethods]
+impl PyStrip {
+    #[getter]
+    fn get_left(self_: PyRef<Self>) -> usize {
+        getter!(self_, Strip, left)
+    }
+
+    #[setter]
+    fn set_left(self_: PyRef<Self>, left: usize) {
+        setter!(self_, Strip, left, left)
+    }
+
+    #[getter]
+    fn get_right(self_: PyRef<Self>) -> usize {
+        getter!(self_, Strip, right)
+    }
+
+    #[setter]
+    fn set_right(self_: PyRef<Self>, right: usize) {
+        setter!(self_, Strip, right, right)
+    }
+
+    #[new]
+    #[pyo3(signature = (left=0, right=0))]
+    fn new(left: usize, right: usize) -> (Self, PyDecoder) {
+        (PyStrip {}, Strip::new(left, right).into())
+    }
+}
+
 /// Metaspace Decoder
 ///
 /// Args:
@@ -497,6 +551,8 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyWordPieceDec>()?;
     m.add_class::<PyByteFallbackDec>()?;
+    m.add_class::<PyFuseDec>()?;
+    m.add_class::<PyStrip>()?;
     m.add_class::<PyMetaspaceDec>()?;
     m.add_class::<PyBPEDecoder>()?;
     m.add_class::<PyCTCDecoder>()?;
diff --git a/bindings/python/tests/bindings/test_decoders.py b/bindings/python/tests/bindings/test_decoders.py
index f5ded6af..fe21e022 100644
--- a/bindings/python/tests/bindings/test_decoders.py
+++ b/bindings/python/tests/bindings/test_decoders.py
@@ -13,6 +13,8 @@ from tokenizers.decoders import (
     WordPiece,
     ByteFallback,
     Replace,
+    Strip,
+    Fuse,
 )
 
 
@@ -94,6 +96,30 @@
         assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
 
 
+class TestFuse:
+    def test_instantiate(self):
+        assert Fuse() is not None
+        assert isinstance(Fuse(), Decoder)
+        assert isinstance(Fuse(), Fuse)
+        assert isinstance(pickle.loads(pickle.dumps(Fuse())), Fuse)
+
+    def test_decoding(self):
+        decoder = Fuse()
+        assert decoder.decode(["My", " na", "me"]) == "My name"
+
+
+class TestStrip:
+    def test_instantiate(self):
+        assert Strip(left=0, right=0) is not None
+        assert isinstance(Strip(left=0, right=0), Decoder)
+        assert isinstance(Strip(left=0, right=0), Strip)
+        assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
+
+    def test_decoding(self):
+        decoder = Strip(left=1, right=0)
+        assert decoder.decode(["My", " na", "me"]) == "ynae"
+
+
 class TestMetaspace:
     def test_instantiate(self):
         assert Metaspace() is not None
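The Python expectations line up with the core crate one-to-one. As a cross-check, here is a minimal Rust mirror of the two `test_decoding` assertions, assuming the crate root re-exports `Decoder` and `Result` and that `decode` joins the output of `decode_chain` as sketched earlier:

    use tokenizers::decoders::fuse::Fuse;
    use tokenizers::decoders::strip::Strip;
    use tokenizers::Decoder;

    fn main() -> tokenizers::Result<()> {
        // TestFuse::test_decoding: every piece fused into a single string.
        let fuse = Fuse::new();
        assert_eq!(fuse.decode(vec!["My".into(), " na".into(), "me".into()])?, "My name");

        // TestStrip::test_decoding: one leading char removed per token,
        // then joined: "y" + "na" + "e" == "ynae".
        let strip = Strip::new(1, 0);
        assert_eq!(strip.decode(vec!["My".into(), " na".into(), "me".into()])?, "ynae");
        Ok(())
    }
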
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 1aef33cf..50a1eb4e 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -67,6 +67,7 @@ thiserror = "1.0.30"
 fancy-regex = { version = "0.10", optional = true}
 getrandom = { version = "0.2.6" }
 esaxx-rs = { version = "0.1", default-features = false, features=[]}
+monostate = "0.1.5"
 
 [features]
 default = ["progressbar", "http", "cli", "onig", "esaxx_fast"]
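The new `monostate` dependency is what makes parameter-free variants safe inside the untagged `DecoderWrapper` enum (see the removed ordering comment in `mod.rs` below): `MustBe!("...")` is a zero-sized type that serializes to the given string literal and refuses to deserialize from anything else. A standalone sketch of the mechanism; `FuseLike` is an illustrative name, not a type from the crate:

    use monostate::MustBe;
    use serde::{Deserialize, Serialize};

    // Illustrative stand-in for a parameter-free decoder such as Fuse.
    #[derive(Serialize, Deserialize)]
    struct FuseLike {
        #[serde(rename = "type")]
        type_: MustBe!("Fuse"),
    }

    fn main() {
        // Deserialization succeeds only when `type` is exactly "Fuse"...
        assert!(serde_json::from_str::<FuseLike>(r#"{"type":"Fuse"}"#).is_ok());
        // ...and fails otherwise, so a field-less variant can no longer
        // swallow every payload that reaches an untagged enum.
        assert!(serde_json::from_str::<FuseLike>(r#"{"type":"Strip"}"#).is_err());
    }
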
diff --git a/tokenizers/src/decoders/byte_fallback.rs b/tokenizers/src/decoders/byte_fallback.rs
index 40e4d137..16788dec 100644
--- a/tokenizers/src/decoders/byte_fallback.rs
+++ b/tokenizers/src/decoders/byte_fallback.rs
@@ -1,4 +1,5 @@
 use crate::tokenizer::{Decoder, Result};
+use monostate::MustBe;
 
 use serde::{Deserialize, Serialize};
 
@@ -6,13 +7,17 @@ use serde::{Deserialize, Serialize};
 /// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
 /// to pure bytes, and attempts to make them into a string. If the tokens
 /// cannot be decoded you will get � instead for each inconvertible byte token
-#[serde(tag = "type")]
 #[non_exhaustive]
-pub struct ByteFallback {}
+pub struct ByteFallback {
+    #[serde(rename = "type")]
+    type_: MustBe!("ByteFallback"),
+}
 
 impl ByteFallback {
     pub fn new() -> Self {
-        Self {}
+        Self {
+            type_: MustBe!("ByteFallback"),
+        }
     }
 }
 
diff --git a/tokenizers/src/decoders/fuse.rs b/tokenizers/src/decoders/fuse.rs
new file mode 100644
index 00000000..5e4a1c11
--- /dev/null
+++ b/tokenizers/src/decoders/fuse.rs
@@ -0,0 +1,43 @@
+use crate::tokenizer::{Decoder, Result};
+use monostate::MustBe;
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Serialize, Deserialize, Default)]
+/// Fuse simply fuses all tokens into one big string.
+/// It's usually the last decoding step anyway, but this
+/// decoder exists in case some decoders need to happen after
+/// that step
+#[non_exhaustive]
+pub struct Fuse {
+    #[serde(rename = "type")]
+    type_: MustBe!("Fuse"),
+}
+
+impl Fuse {
+    pub fn new() -> Self {
+        Self {
+            type_: MustBe!("Fuse"),
+        }
+    }
+}
+
+impl Decoder for Fuse {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        let new_string = tokens.join("");
+        Ok(vec![new_string])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn decode() {
+        let decoder = Fuse::new();
+        let res = decoder
+            .decode_chain(vec!["Hey".into(), " friend!".into()])
+            .unwrap();
+        assert_eq!(res, vec!["Hey friend!"]);
+    }
+}
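Returning a one-element `Vec` instead of a `String` is what lets `Fuse` sit mid-chain: any decoder placed after it sees the fused text as a single token. A hedged sketch chaining it with `Strip` through the existing `Sequence` decoder (ASCII input on purpose; see the byte-slicing note after `strip.rs` below), again assuming the crate-root re-exports:

    use tokenizers::decoders::fuse::Fuse;
    use tokenizers::decoders::sequence::Sequence;
    use tokenizers::decoders::strip::Strip;
    use tokenizers::Decoder;

    fn main() -> tokenizers::Result<()> {
        // Fuse first, then Strip: one leading char is removed from the
        // fused string as a whole, instead of one per original token.
        let seq = Sequence::new(vec![Fuse::new().into(), Strip::new(1, 0).into()]);
        let out = seq.decode_chain(vec!["Hey".into(), " friend!".into()])?;
        assert_eq!(out, vec!["ey friend!"]);
        Ok(())
    }
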
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index 9c85928c..94204b8f 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -1,7 +1,9 @@
 pub mod bpe;
 pub mod byte_fallback;
 pub mod ctc;
+pub mod fuse;
 pub mod sequence;
+pub mod strip;
 pub mod wordpiece;
 
 // Re-export these as decoders
@@ -13,7 +15,9 @@ use serde::{Deserialize, Serialize};
 use crate::decoders::bpe::BPEDecoder;
 use crate::decoders::byte_fallback::ByteFallback;
 use crate::decoders::ctc::CTC;
+use crate::decoders::fuse::Fuse;
 use crate::decoders::sequence::Sequence;
+use crate::decoders::strip::Strip;
 use crate::decoders::wordpiece::WordPiece;
 use crate::normalizers::replace::Replace;
 use crate::pre_tokenizers::byte_level::ByteLevel;
@@ -30,10 +34,8 @@ pub enum DecoderWrapper {
     CTC(CTC),
     Sequence(Sequence),
     Replace(Replace),
-    // XXX: This is an untagged enum, which unfortunately means order
-    // is **CRITICAL**. We absolutely need to make sure order is correct.
-    // Since byte fallback is parameter free, is **has** to be last, and will
-    // unfortunately match pretty much everything.
+    Fuse(Fuse),
+    Strip(Strip),
     ByteFallback(ByteFallback),
 }
 
@@ -48,6 +50,8 @@ impl Decoder for DecoderWrapper {
             Self::Sequence(seq) => seq.decode_chain(tokens),
             Self::Replace(seq) => seq.decode_chain(tokens),
             Self::ByteFallback(bf) => bf.decode_chain(tokens),
+            Self::Strip(bf) => bf.decode_chain(tokens),
+            Self::Fuse(bf) => bf.decode_chain(tokens),
         }
     }
 }
@@ -55,6 +59,8 @@ impl Decoder for DecoderWrapper {
 impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
 impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
 impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
+impl_enum_from!(Fuse, DecoderWrapper, Fuse);
+impl_enum_from!(Strip, DecoderWrapper, Strip);
 impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
@@ -72,4 +78,18 @@ mod tests {
         let serialized = serde_json::to_string(&decoder).unwrap();
         assert_eq!(serialized, json);
     }
+
+    #[test]
+    fn decoder_serialization_other_no_arg() {
+        let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
+        let serialized = serde_json::to_string(&decoder).unwrap();
+        assert_eq!(serialized, json);
+    }
+
+    #[test]
+    fn decoder_serialization_no_decode() {
+        let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
+        assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
+    }
 }
diff --git a/tokenizers/src/decoders/strip.rs b/tokenizers/src/decoders/strip.rs
new file mode 100644
index 00000000..1691c707
--- /dev/null
+++ b/tokenizers/src/decoders/strip.rs
@@ -0,0 +1,49 @@
+use crate::tokenizer::{Decoder, Result};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Deserialize, Clone, Debug, Serialize, Default)]
+/// Strip is a simple decoder that removes a fixed number of
+/// characters from each token: `left` characters from its start
+/// and `right` characters from its end
+#[serde(tag = "type")]
+#[non_exhaustive]
+pub struct Strip {
+    pub left: usize,
+    pub right: usize,
+}
+
+impl Strip {
+    pub fn new(left: usize, right: usize) -> Self {
+        Self { left, right }
+    }
+}
+
+impl Decoder for Strip {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        Ok(tokens
+            .into_iter()
+            .map(|token| token[self.left..token.len() - self.right].to_string())
+            .collect())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn decode() {
+        let decoder = Strip::new(1, 0);
+        let res = decoder
+            .decode_chain(vec!["Hey".into(), " friend!".into()])
+            .unwrap();
+        assert_eq!(res, vec!["ey", "friend!"]);
+
+        let decoder = Strip::new(0, 1);
+        let res = decoder
+            .decode_chain(vec!["Hey".into(), " friend!".into()])
+            .unwrap();
+        assert_eq!(res, vec!["He", " friend"]);
+    }
+}
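One caveat on `decode_chain` above: `token[self.left..token.len() - self.right]` slices bytes, not chars. That is safe for ASCII, where the two coincide, but a slice boundary inside a multi-byte char (e.g. `▁`, three bytes in UTF-8) panics on the char-boundary check, and `token.len() - self.right` underflows when `right` exceeds a token's byte length. A self-contained sketch of the boundary behaviour (helper name is illustrative):

    // Same byte-based slicing strategy as Strip::decode_chain.
    fn strip_left(token: &str, left: usize) -> &str {
        &token[left..]
    }

    fn main() {
        // Fine for ASCII: byte offsets and char offsets coincide.
        assert_eq!(strip_left("Hey", 1), "ey");

        // Would panic: byte 1 of "▁Hey" falls inside the 3-byte '▁'.
        // let _ = strip_left("▁Hey", 1);

        // A char-based variant sidesteps the boundary issue:
        let stripped: String = "▁Hey".chars().skip(1).collect();
        assert_eq!(stripped, "Hey");
    }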