diff --git a/bindings/node/lib/bindings/decoders.d.ts b/bindings/node/lib/bindings/decoders.d.ts
index b00e5f87..d6e809b8 100644
--- a/bindings/node/lib/bindings/decoders.d.ts
+++ b/bindings/node/lib/bindings/decoders.d.ts
@@ -12,6 +12,13 @@ interface Decoder {
  */
 export function byteLevelDecoder(): Decoder;
 
+/**
+ * Instantiate a new Replace Decoder
+ * @param [pattern] The pattern to replace
+ * @param [content] The replacement.
+ */
+export function replaceDecoder(pattern: string, content: string): Decoder;
+
 /**
  * Instantiate a new WordPiece Decoder
  * @param [prefix='##'] The prefix to use for subwords that are not a beginning-of-word
diff --git a/bindings/node/lib/bindings/decoders.js b/bindings/node/lib/bindings/decoders.js
index 6c1a66c8..9d10b07f 100644
--- a/bindings/node/lib/bindings/decoders.js
+++ b/bindings/node/lib/bindings/decoders.js
@@ -2,6 +2,7 @@ const native = require("./native");
 
 module.exports = {
   byteLevelDecoder: native.decoders_ByteLevel,
+  replaceDecoder: native.decoders_Replace,
   wordPieceDecoder: native.decoders_WordPiece,
   byteFallbackDecoder: native.decoders_ByteFallback,
   metaspaceDecoder: native.decoders_Metaspace,
diff --git a/bindings/node/lib/bindings/decoders.test.ts b/bindings/node/lib/bindings/decoders.test.ts
index ec80040e..7e31b576 100644
--- a/bindings/node/lib/bindings/decoders.test.ts
+++ b/bindings/node/lib/bindings/decoders.test.ts
@@ -3,6 +3,7 @@ import {
   byteFallbackDecoder,
   ctcDecoder,
   metaspaceDecoder,
+  replaceDecoder,
   sequenceDecoder,
   wordPieceDecoder,
 } from "./decoders";
@@ -44,6 +45,12 @@ describe("byteFallbackDecoder", () => {
   });
 });
 
+describe("replaceDecoder", () => {
+  it("can decode arrays of strings", () => {
+    expect(replaceDecoder("_", " ").decode(["Hello", "_Hello"])).toEqual("Hello Hello");
+  });
+});
+
 describe("metaspaceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
     expect(metaspaceDecoder(undefined)).toBeDefined();
diff --git a/bindings/node/native/src/decoders.rs b/bindings/node/native/src/decoders.rs
index ddef37ad..c8433b8d 100644
--- a/bindings/node/native/src/decoders.rs
+++ b/bindings/node/native/src/decoders.rs
@@ -57,6 +57,20 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// replace()
+fn replace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let pattern: String = cx.extract::<String>(0)?;
+    let content: String = cx.extract::<String>(1)?;
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::normalizers::replace::Replace::new(pattern, content)
+            .map_err(|e| Error(e.to_string()))?
+            .into(),
+    ));
+    Ok(decoder)
+}
+
 /// wordpiece(prefix: String = "##", cleanup: bool)
 fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     let prefix = cx
@@ -156,6 +170,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
 /// Register everything here
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
+    m.export_function(&format!("{}_Replace", prefix), replace)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.py b/bindings/python/py_src/tokenizers/decoders/__init__.py
index bddc9be2..44c906e4 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.py
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.py
@@ -3,6 +3,7 @@ from .. import decoders
 
 Decoder = decoders.Decoder
 ByteLevel = decoders.ByteLevel
+Replace = decoders.Replace
 WordPiece = decoders.WordPiece
 ByteFallback = decoders.ByteFallback
 Metaspace = decoders.Metaspace
diff --git a/bindings/python/py_src/tokenizers/decoders/__init__.pyi b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
index 06150e98..ea174e92 100644
--- a/bindings/python/py_src/tokenizers/decoders/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -150,6 +150,29 @@ class Metaspace(Decoder):
         """
         pass
 
+class Replace(Decoder):
+    """
+    Replace Decoder
+
+    This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
+    :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
+    """
+
+    def __init__(self, pattern, content):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Sequence(Decoder):
     """
     Sequence Decoder
diff --git a/bindings/python/src/decoders.rs b/bindings/python/src/decoders.rs
index e2405690..85d76ec0 100644
--- a/bindings/python/src/decoders.rs
+++ b/bindings/python/src/decoders.rs
@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use crate::utils::PyChar;
+use crate::utils::PyPattern;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -14,6 +15,7 @@ use tk::decoders::metaspace::Metaspace;
 use tk::decoders::sequence::Sequence;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
+use tk::normalizers::replace::Replace;
 use tk::Decoder;
 use tokenizers as tk;
 
@@ -46,6 +48,7 @@ impl PyDecoder {
                     Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
                 }
                 DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
+                DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
                 DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
                 DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
                 DecoderWrapper::Sequence(_) => {
@@ -159,6 +162,24 @@ impl PyByteLevelDec {
     }
 }
 
+/// Replace Decoder
+///
+/// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
+/// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Replace")]
+#[pyo3(text_signature = "(self, pattern, content)")]
+pub struct PyReplaceDec {}
+#[pymethods]
+impl PyReplaceDec {
+    #[new]
+    fn new(pattern: PyPattern, content: String) -> PyResult<(Self, PyDecoder)> {
+        Ok((
+            PyReplaceDec {},
+            ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
+        ))
+    }
+}
+
 /// WordPiece Decoder
 ///
 /// Args:
@@ -473,6 +494,7 @@ impl Decoder for PyDecoderWrapper {
 pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyDecoder>()?;
     m.add_class::<PyByteLevelDec>()?;
+    m.add_class::<PyReplaceDec>()?;
     m.add_class::<PyWordPieceDec>()?;
     m.add_class::<PyByteFallbackDec>()?;
     m.add_class::<PyMetaspaceDec>()?;
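
For reference, a minimal usage sketch of the new Python Replace decoder, mirroring the test added below; `decode` rewrites each token independently, then concatenates the results into one string:

    from tokenizers.decoders import Replace

    # Every occurrence of the pattern in each token becomes the content string.
    decoder = Replace("_", " ")
    assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"
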
diff --git a/bindings/python/tests/bindings/test_decoders.py b/bindings/python/tests/bindings/test_decoders.py
index f404f922..f5ded6af 100644
--- a/bindings/python/tests/bindings/test_decoders.py
+++ b/bindings/python/tests/bindings/test_decoders.py
@@ -3,7 +3,17 @@ import pickle
 
 import pytest
 
-from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
+from tokenizers.decoders import (
+    CTC,
+    BPEDecoder,
+    ByteLevel,
+    Decoder,
+    Metaspace,
+    Sequence,
+    WordPiece,
+    ByteFallback,
+    Replace,
+)
 
 
 class TestByteLevel:
@@ -24,6 +34,18 @@ class TestByteLevel:
         assert isinstance(reloaded, ByteLevel)
 
 
+class TestReplace:
+    def test_instantiate(self):
+        assert Replace("_", " ") is not None
+        assert isinstance(Replace("_", " "), Decoder)
+        assert isinstance(Replace("_", " "), Replace)
+        # assert isinstance(pickle.loads(pickle.dumps(Replace("_", " "))), Replace)
+
+    def test_decoding(self):
+        decoder = Replace("_", " ")
+        assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"
+
+
 class TestWordPiece:
     def test_instantiate(self):
         assert WordPiece() is not None
diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs
index 8e94efb7..9c85928c 100644
--- a/tokenizers/src/decoders/mod.rs
+++ b/tokenizers/src/decoders/mod.rs
@@ -15,6 +15,7 @@ use crate::decoders::byte_fallback::ByteFallback;
 use crate::decoders::ctc::CTC;
 use crate::decoders::sequence::Sequence;
 use crate::decoders::wordpiece::WordPiece;
+use crate::normalizers::replace::Replace;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
 use crate::{Decoder, Result};
@@ -28,6 +29,7 @@ pub enum DecoderWrapper {
     Metaspace(Metaspace),
     CTC(CTC),
     Sequence(Sequence),
+    Replace(Replace),
     // XXX: This is an untagged enum, which unfortunately means order
     // is **CRITICAL**. We absolutely need to make sure order is correct.
     // Since byte fallback is parameter free, it **has** to be last, and will
@@ -44,6 +46,7 @@ impl Decoder for DecoderWrapper {
             Self::WordPiece(wp) => wp.decode_chain(tokens),
             Self::CTC(ctc) => ctc.decode_chain(tokens),
             Self::Sequence(seq) => seq.decode_chain(tokens),
+            Self::Replace(seq) => seq.decode_chain(tokens),
             Self::ByteFallback(bf) => bf.decode_chain(tokens),
         }
     }
@@ -56,6 +59,7 @@ impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
 impl_enum_from!(Sequence, DecoderWrapper, Sequence);
+impl_enum_from!(Replace, DecoderWrapper, Replace);
 
 #[cfg(test)]
 mod tests {
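
Since `Replace` now has its own `DecoderWrapper` variant, it also composes with the existing decoders. A small sketch using `Sequence` from the Python API (the composition itself is illustrative, not part of the patch):

    from tokenizers.decoders import Replace, Sequence

    # Decoders in a Sequence run left to right, each consuming the
    # previous decoder's output tokens.
    decoder = Sequence([Replace("_", " ")])
    assert decoder.decode(["Hello", "_Hello"]) == "Hello Hello"
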
diff --git a/tokenizers/src/normalizers/replace.rs b/tokenizers/src/normalizers/replace.rs
index fb42222f..cdd4a420 100644
--- a/tokenizers/src/normalizers/replace.rs
+++ b/tokenizers/src/normalizers/replace.rs
@@ -1,3 +1,5 @@
+use crate::tokenizer::pattern::Pattern;
+use crate::tokenizer::Decoder;
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
 use crate::utils::SysRegex;
 use serde::{Deserialize, Serialize};
@@ -83,6 +85,26 @@ impl Normalizer for Replace {
     }
 }
 
+impl Decoder for Replace {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        tokens
+            .into_iter()
+            .map(|token| -> Result<String> {
+                let mut new_token = "".to_string();
+
+                for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
+                    if is_match {
+                        new_token.push_str(&self.content);
+                    } else {
+                        new_token.push_str(&token[start..stop]);
+                    }
+                }
+                Ok(new_token)
+            })
+            .collect()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -124,4 +146,14 @@ mod tests {
         assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
         assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);
     }
+
+    #[test]
+    fn test_replace_decode() {
+        let original = vec!["hello".to_string(), "_hello".to_string()];
+        let replace = Replace::new("_", " ").unwrap();
+        assert_eq!(
+            replace.decode_chain(original).unwrap(),
+            vec!["hello", " hello"]
+        );
+    }
 }
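
The heart of the change is the `decode_chain` implementation above: for each token, every regex match is replaced by `self.content`, while non-matching spans are copied through verbatim. A rough Python equivalent, assuming a plain-string pattern:

    import re

    def decode_chain(tokens, pattern, content):
        # Every match of `pattern` becomes `content`; everything else
        # is copied through unchanged, token by token.
        return [re.sub(re.escape(pattern), content, token) for token in tokens]

    assert decode_chain(["hello", "_hello"], "_", " ") == ["hello", " hello"]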