Adding 2 new decoders: (#1196)

* Adding 2 new decoders:

- Fuse will simply concatenate all tokens into 1 string
- Strip will remove n char from left or right

Sequence(Replace("_", " "), Fuse(), Strip(1, 0)) should be what we want
for the `Metaspace` thing.

- Note: Added a new dependency from better parsing of decoders.
This is due to untagged enums which can match anything the `MustBe`
ensure there's no issue between Fuse and ByteFallback.
Since both are new the chances for backward incompatibility is low.

* Fixing picking/unpickling (using default args.).

* Stub.

* Black.

* Fixing node.
This commit is contained in:
Nicolas Patry
2023-03-24 00:50:54 +01:00
committed by GitHub
parent d2c8190a0f
commit e4aea890d5
13 changed files with 311 additions and 7 deletions

View File

@ -35,6 +35,18 @@ export function wordPieceDecoder(prefix?: string, cleanup?: boolean): Decoder;
*/ */
export function byteFallbackDecoder(): Decoder; export function byteFallbackDecoder(): Decoder;
/**
* Instantiate a new Fuse Decoder which fuses all tokens into one string
*/
export function fuseDecoder(): Decoder;
/**
* Instantiate a new Strip Decoder
* @param [left] The number of chars to remove from the left of each token
* @param [right] The number of chars to remove from the right of each token
*/
export function stripDecoder(left: number, right: number): Decoder;
/** /**
* Instantiate a new Metaspace * Instantiate a new Metaspace
* *

View File

@ -5,6 +5,8 @@ module.exports = {
replaceDecoder: native.decoders_Replace, replaceDecoder: native.decoders_Replace,
wordPieceDecoder: native.decoders_WordPiece, wordPieceDecoder: native.decoders_WordPiece,
byteFallbackDecoder: native.decoders_ByteFallback, byteFallbackDecoder: native.decoders_ByteFallback,
fuseDecoder: native.decoders_Fuse,
stripDecoder: native.decoders_Strip,
metaspaceDecoder: native.decoders_Metaspace, metaspaceDecoder: native.decoders_Metaspace,
bpeDecoder: native.decoders_BPEDecoder, bpeDecoder: native.decoders_BPEDecoder,
ctcDecoder: native.decoders_CTC, ctcDecoder: native.decoders_CTC,

View File

@ -2,9 +2,11 @@ import {
bpeDecoder, bpeDecoder,
byteFallbackDecoder, byteFallbackDecoder,
ctcDecoder, ctcDecoder,
fuseDecoder,
metaspaceDecoder, metaspaceDecoder,
replaceDecoder, replaceDecoder,
sequenceDecoder, sequenceDecoder,
stripDecoder,
wordPieceDecoder, wordPieceDecoder,
} from "./decoders"; } from "./decoders";
@ -51,6 +53,26 @@ describe("replaceDecoder", () => {
}); });
}); });
describe("fuseDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(fuseDecoder()).toBeDefined();
});
it("can decode arrays of strings", () => {
expect(fuseDecoder().decode(["Hel", "lo"])).toEqual("Hello");
});
});
describe("stripDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(stripDecoder(0, 0)).toBeDefined();
});
it("can decode arrays of strings", () => {
expect(stripDecoder(1, 0).decode(["Hel", "lo"])).toEqual("elo");
});
});
describe("metaspaceDecoder", () => { describe("metaspaceDecoder", () => {
it("accepts `undefined` as first parameter", () => { it("accepts `undefined` as first parameter", () => {
expect(metaspaceDecoder(undefined)).toBeDefined(); expect(metaspaceDecoder(undefined)).toBeDefined();

View File

@ -96,6 +96,26 @@ fn byte_fallback(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder) Ok(decoder)
} }
/// fuse()
fn fuse(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new(tk::decoders::fuse::Fuse::new().into()));
Ok(decoder)
}
/// strip()
fn strip(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let left: usize = cx.extract(0)?;
let right: usize = cx.extract(1)?;
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
tk::decoders::strip::Strip::new(left, right).into(),
));
Ok(decoder)
}
/// metaspace(replacement: String = "_", add_prefix_space: bool = true) /// metaspace(replacement: String = "_", add_prefix_space: bool = true)
fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> { fn metaspace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁'); let replacement = cx.extract_opt::<char>(0)?.unwrap_or('▁');
@ -173,6 +193,8 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_Replace", prefix), replace)?; m.export_function(&format!("{}_Replace", prefix), replace)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?; m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?; m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
m.export_function(&format!("{}_Fuse", prefix), fuse)?;
m.export_function(&format!("{}_Strip", prefix), strip)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?; m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?; m.export_function(&format!("{}_BPEDecoder", prefix), bpe_decoder)?;
m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?; m.export_function(&format!("{}_CTC", prefix), ctc_decoder)?;

View File

@ -6,6 +6,8 @@ ByteLevel = decoders.ByteLevel
Replace = decoders.Replace Replace = decoders.Replace
WordPiece = decoders.WordPiece WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback ByteFallback = decoders.ByteFallback
Fuse = decoders.Fuse
Strip = decoders.Strip
Metaspace = decoders.Metaspace Metaspace = decoders.Metaspace
BPEDecoder = decoders.BPEDecoder BPEDecoder = decoders.BPEDecoder
CTC = decoders.CTC CTC = decoders.CTC

View File

@ -121,6 +121,29 @@ class CTC(Decoder):
""" """
pass pass
class Fuse(Decoder):
"""
Fuse Decoder
Fuse simply fuses every token into a single string.
This is the last step of decoding, this decoder exists only if
there is need to add other decoders *after* the fusion
"""
def __init__(self):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class Metaspace(Decoder): class Metaspace(Decoder):
""" """
Metaspace Decoder Metaspace Decoder
@ -197,6 +220,27 @@ class Sequence(Decoder):
""" """
pass pass
class Strip(Decoder):
"""
Strip normalizer
Strips n left characters of each token, or n right characters of each token
"""
def __init__(self, left=0, right=0):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class WordPiece(Decoder): class WordPiece(Decoder):
""" """
WordPiece Decoder WordPiece Decoder

View File

@ -11,8 +11,10 @@ use tk::decoders::bpe::BPEDecoder;
use tk::decoders::byte_fallback::ByteFallback; use tk::decoders::byte_fallback::ByteFallback;
use tk::decoders::byte_level::ByteLevel; use tk::decoders::byte_level::ByteLevel;
use tk::decoders::ctc::CTC; use tk::decoders::ctc::CTC;
use tk::decoders::fuse::Fuse;
use tk::decoders::metaspace::Metaspace; use tk::decoders::metaspace::Metaspace;
use tk::decoders::sequence::Sequence; use tk::decoders::sequence::Sequence;
use tk::decoders::strip::Strip;
use tk::decoders::wordpiece::WordPiece; use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper; use tk::decoders::DecoderWrapper;
use tk::normalizers::replace::Replace; use tk::normalizers::replace::Replace;
@ -47,6 +49,8 @@ impl PyDecoder {
DecoderWrapper::ByteFallback(_) => { DecoderWrapper::ByteFallback(_) => {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py) Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
} }
DecoderWrapper::Strip(_) => Py::new(py, (PyStrip {}, base))?.into_py(py),
DecoderWrapper::Fuse(_) => Py::new(py, (PyFuseDec {}, base))?.into_py(py),
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py), DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py), DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py), DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
@ -238,6 +242,56 @@ impl PyByteFallbackDec {
} }
} }
/// Fuse Decoder
/// Fuse simply fuses every token into a single string.
/// This is the last step of decoding, this decoder exists only if
/// there is need to add other decoders *after* the fusion
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Fuse")]
#[pyo3(text_signature = "(self)")]
pub struct PyFuseDec {}
#[pymethods]
impl PyFuseDec {
#[new]
#[pyo3(signature = ())]
fn new() -> (Self, PyDecoder) {
(PyFuseDec {}, Fuse::new().into())
}
}
/// Strip normalizer
/// Strips n left characters of each token, or n right characters of each token
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Strip")]
#[pyo3(text_signature = "(self, left=0, right=0)")]
pub struct PyStrip {}
#[pymethods]
impl PyStrip {
#[getter]
fn get_left(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, left)
}
#[setter]
fn set_left(self_: PyRef<Self>, left: usize) {
setter!(self_, Strip, left, left)
}
#[getter]
fn get_right(self_: PyRef<Self>) -> usize {
getter!(self_, Strip, right)
}
#[setter]
fn set_right(self_: PyRef<Self>, right: usize) {
setter!(self_, Strip, right, right)
}
#[new]
#[pyo3(signature = (left=0, right=0))]
fn new(left: usize, right: usize) -> (Self, PyDecoder) {
(PyStrip {}, Strip::new(left, right).into())
}
}
/// Metaspace Decoder /// Metaspace Decoder
/// ///
/// Args: /// Args:
@ -497,6 +551,8 @@ pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyReplaceDec>()?; m.add_class::<PyReplaceDec>()?;
m.add_class::<PyWordPieceDec>()?; m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?; m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyFuseDec>()?;
m.add_class::<PyStrip>()?;
m.add_class::<PyMetaspaceDec>()?; m.add_class::<PyMetaspaceDec>()?;
m.add_class::<PyBPEDecoder>()?; m.add_class::<PyBPEDecoder>()?;
m.add_class::<PyCTCDecoder>()?; m.add_class::<PyCTCDecoder>()?;

View File

@ -13,6 +13,8 @@ from tokenizers.decoders import (
WordPiece, WordPiece,
ByteFallback, ByteFallback,
Replace, Replace,
Strip,
Fuse,
) )
@ -94,6 +96,30 @@ class TestByteFallback:
assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a" assert decoder.decode(["<0xE5>", "<0x8f>", "<0xab>", "a"]) == "叫a"
class TestFuse:
def test_instantiate(self):
assert Fuse() is not None
assert isinstance(Fuse(), Decoder)
assert isinstance(Fuse(), Fuse)
assert isinstance(pickle.loads(pickle.dumps(Fuse())), Fuse)
def test_decoding(self):
decoder = Fuse()
assert decoder.decode(["My", " na", "me"]) == "My name"
class TestStrip:
def test_instantiate(self):
assert Strip(left=0, right=0) is not None
assert isinstance(Strip(left=0, right=0), Decoder)
assert isinstance(Strip(left=0, right=0), Strip)
assert isinstance(pickle.loads(pickle.dumps(Strip(left=0, right=0))), Strip)
def test_decoding(self):
decoder = Strip(left=1, right=0)
assert decoder.decode(["My", " na", "me"]) == "ynae"
class TestMetaspace: class TestMetaspace:
def test_instantiate(self): def test_instantiate(self):
assert Metaspace() is not None assert Metaspace() is not None

View File

@ -67,6 +67,7 @@ thiserror = "1.0.30"
fancy-regex = { version = "0.10", optional = true} fancy-regex = { version = "0.10", optional = true}
getrandom = { version = "0.2.6" } getrandom = { version = "0.2.6" }
esaxx-rs = { version = "0.1", default-features = false, features=[]} esaxx-rs = { version = "0.1", default-features = false, features=[]}
monostate = "0.1.5"
[features] [features]
default = ["progressbar", "http", "cli", "onig", "esaxx_fast"] default = ["progressbar", "http", "cli", "onig", "esaxx_fast"]

View File

@ -1,4 +1,5 @@
use crate::tokenizer::{Decoder, Result}; use crate::tokenizer::{Decoder, Result};
use monostate::MustBe;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -6,13 +7,17 @@ use serde::{Deserialize, Serialize};
/// ByteFallback is a simple trick which converts tokens looking like `<0x61>` /// ByteFallback is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens /// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded you will get <20> instead for each inconvertable byte token /// cannot be decoded you will get <20> instead for each inconvertable byte token
#[serde(tag = "type")]
#[non_exhaustive] #[non_exhaustive]
pub struct ByteFallback {} pub struct ByteFallback {
#[serde(rename = "type")]
type_: MustBe!("ByteFallback"),
}
impl ByteFallback { impl ByteFallback {
pub fn new() -> Self { pub fn new() -> Self {
Self {} Self {
type_: MustBe!("ByteFallback"),
}
} }
} }

View File

@ -0,0 +1,43 @@
use crate::tokenizer::{Decoder, Result};
use monostate::MustBe;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
/// Fuse simply fuses all tokens into one big string.
/// It's usually the last decoding step anyway, but this
/// decoder exists incase some decoders need to happen after that
/// step
#[non_exhaustive]
pub struct Fuse {
#[serde(rename = "type")]
type_: MustBe!("Fuse"),
}
impl Fuse {
pub fn new() -> Self {
Self {
type_: MustBe!("Fuse"),
}
}
}
impl Decoder for Fuse {
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
let new_string = tokens.join("");
Ok(vec![new_string])
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn decode() {
let decoder = Fuse::new();
let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()])
.unwrap();
assert_eq!(res, vec!["Hey friend!"]);
}
}

View File

@ -1,7 +1,9 @@
pub mod bpe; pub mod bpe;
pub mod byte_fallback; pub mod byte_fallback;
pub mod ctc; pub mod ctc;
pub mod fuse;
pub mod sequence; pub mod sequence;
pub mod strip;
pub mod wordpiece; pub mod wordpiece;
// Re-export these as decoders // Re-export these as decoders
@ -13,7 +15,9 @@ use serde::{Deserialize, Serialize};
use crate::decoders::bpe::BPEDecoder; use crate::decoders::bpe::BPEDecoder;
use crate::decoders::byte_fallback::ByteFallback; use crate::decoders::byte_fallback::ByteFallback;
use crate::decoders::ctc::CTC; use crate::decoders::ctc::CTC;
use crate::decoders::fuse::Fuse;
use crate::decoders::sequence::Sequence; use crate::decoders::sequence::Sequence;
use crate::decoders::strip::Strip;
use crate::decoders::wordpiece::WordPiece; use crate::decoders::wordpiece::WordPiece;
use crate::normalizers::replace::Replace; use crate::normalizers::replace::Replace;
use crate::pre_tokenizers::byte_level::ByteLevel; use crate::pre_tokenizers::byte_level::ByteLevel;
@ -30,10 +34,8 @@ pub enum DecoderWrapper {
CTC(CTC), CTC(CTC),
Sequence(Sequence), Sequence(Sequence),
Replace(Replace), Replace(Replace),
// XXX: This is an untagged enum, which unfortunately means order Fuse(Fuse),
// is **CRITICAL**. We absolutely need to make sure order is correct. Strip(Strip),
// Since byte fallback is parameter free, is **has** to be last, and will
// unfortunately match pretty much everything.
ByteFallback(ByteFallback), ByteFallback(ByteFallback),
} }
@ -48,6 +50,8 @@ impl Decoder for DecoderWrapper {
Self::Sequence(seq) => seq.decode_chain(tokens), Self::Sequence(seq) => seq.decode_chain(tokens),
Self::Replace(seq) => seq.decode_chain(tokens), Self::Replace(seq) => seq.decode_chain(tokens),
Self::ByteFallback(bf) => bf.decode_chain(tokens), Self::ByteFallback(bf) => bf.decode_chain(tokens),
Self::Strip(bf) => bf.decode_chain(tokens),
Self::Fuse(bf) => bf.decode_chain(tokens),
} }
} }
} }
@ -55,6 +59,8 @@ impl Decoder for DecoderWrapper {
impl_enum_from!(BPEDecoder, DecoderWrapper, BPE); impl_enum_from!(BPEDecoder, DecoderWrapper, BPE);
impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel); impl_enum_from!(ByteLevel, DecoderWrapper, ByteLevel);
impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback); impl_enum_from!(ByteFallback, DecoderWrapper, ByteFallback);
impl_enum_from!(Fuse, DecoderWrapper, Fuse);
impl_enum_from!(Strip, DecoderWrapper, Strip);
impl_enum_from!(Metaspace, DecoderWrapper, Metaspace); impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
impl_enum_from!(WordPiece, DecoderWrapper, WordPiece); impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
impl_enum_from!(CTC, DecoderWrapper, CTC); impl_enum_from!(CTC, DecoderWrapper, CTC);
@ -72,4 +78,18 @@ mod tests {
let serialized = serde_json::to_string(&decoder).unwrap(); let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json); assert_eq!(serialized, json);
} }
#[test]
fn decoder_serialization_other_no_arg() {
let json = r#"{"type":"Sequence","decoders":[{"type":"Fuse"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
let decoder: DecoderWrapper = serde_json::from_str(json).unwrap();
let serialized = serde_json::to_string(&decoder).unwrap();
assert_eq!(serialized, json);
}
#[test]
fn decoder_serialization_no_decode() {
let json = r#"{"type":"Sequence","decoders":[{},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#;
assert!(serde_json::from_str::<DecoderWrapper>(json).is_err());
}
} }

View File

@ -0,0 +1,49 @@
use crate::tokenizer::{Decoder, Result};
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Clone, Debug, Serialize, Default)]
/// Strip is a simple trick which converts tokens looking like `<0x61>`
/// to pure bytes, and attempts to make them into a string. If the tokens
/// cannot be decoded you will get <20> instead for each inconvertable byte token
#[serde(tag = "type")]
#[non_exhaustive]
pub struct Strip {
pub left: usize,
pub right: usize,
}
impl Strip {
pub fn new(left: usize, right: usize) -> Self {
Self { left, right }
}
}
impl Decoder for Strip {
fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
Ok(tokens
.into_iter()
.map(|token| token[self.left..token.len() - self.right].to_string())
.collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn decode() {
let decoder = Strip::new(1, 0);
let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()])
.unwrap();
assert_eq!(res, vec!["ey", "friend!"]);
let decoder = Strip::new(0, 1);
let res = decoder
.decode_chain(vec!["Hey".into(), " friend!".into()])
.unwrap();
assert_eq!(res, vec!["He", " friend"]);
}
}