Adding Replace to decoder (to undo the Replace Normalizer for Metaspace split) (#1195)
This commit is contained in:
Nicolas Patry
2023-03-23 23:43:47 +01:00
committed by GitHub
parent 178e294a6a
commit 250d46c676
10 changed files with 135 additions and 1 deletion

View File

@ -12,6 +12,13 @@ interface Decoder {
*/
export function byteLevelDecoder(): Decoder;
/**
 * Instantiate a new Replace Decoder
 * @param pattern The pattern to replace
 * @param content The replacement.
 */
export function replaceDecoder(pattern: string, content: string): Decoder;
/**
* Instantiate a new WordPiece Decoder
* @param [prefix='##'] The prefix to use for subwords that are not a beginning-of-word

View File

@ -2,6 +2,7 @@ const native = require("./native");
module.exports = {
byteLevelDecoder: native.decoders_ByteLevel,
replaceDecoder: native.decoders_Replace,
wordPieceDecoder: native.decoders_WordPiece,
byteFallbackDecoder: native.decoders_ByteFallback,
metaspaceDecoder: native.decoders_Metaspace,

View File

@ -3,6 +3,7 @@ import {
byteFallbackDecoder,
ctcDecoder,
metaspaceDecoder,
replaceDecoder,
sequenceDecoder,
wordPieceDecoder,
} from "./decoders";
@ -44,6 +45,12 @@ describe("byteFallbackDecoder", () => {
});
});
describe("replaceDecoder", () => {
  it("can decode arrays of strings", () => {
    const decoder = replaceDecoder("_", " ");
    const decoded = decoder.decode(["Hello", "_Hello"]);
    expect(decoded).toEqual("Hello Hello");
  });
});
describe("metaspaceDecoder", () => {
it("accepts `undefined` as first parameter", () => {
expect(metaspaceDecoder(undefined)).toBeDefined();

View File

@ -57,6 +57,20 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsDecoder> {
Ok(decoder)
}
/// replace()
fn replace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let pattern: String = cx.extract::<String>(0)?;
let content: String = cx.extract::<String>(1)?;
let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
let guard = cx.lock();
decoder.borrow_mut(&guard).decoder = Some(Arc::new(
tk::normalizers::replace::Replace::new(pattern, content)
.map_err(|e| Error(e.to_string()))?
.into(),
));
Ok(decoder)
}
/// wordpiece(prefix: String = "##", cleanup: bool)
fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
let prefix = cx
@ -156,6 +170,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
/// Register everything here
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
m.export_function(&format!("{}_Replace", prefix), replace)?;
m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;

View File

@ -3,6 +3,7 @@ from .. import decoders
# Re-export the native decoder classes so they are importable from this module.
Decoder = decoders.Decoder
ByteLevel = decoders.ByteLevel
Replace = decoders.Replace
WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback
Metaspace = decoders.Metaspace

View File

@ -150,6 +150,29 @@ class Metaspace(Decoder):
"""
pass
class Replace(Decoder):
    """
    Replace Decoder

    Applies a ``pattern -> content`` replacement to each token when decoding.

    NOTE(review): the original text referenced
    :class:`~tokenizers.pre_tokenizers.Replace`, but the Rust implementation
    wraps ``tk::normalizers::replace::Replace`` — confirm whether the intended
    cross-reference is :class:`~tokenizers.normalizers.Replace`.
    """

    def __init__(self, pattern, content):
        # pattern: what to search for in each token; content: its replacement.
        pass

    def decode(self, tokens):
        """
        Decode the given list of tokens to a final string

        Args:
            tokens (:obj:`List[str]`):
                The list of tokens to decode

        Returns:
            :obj:`str`: The decoded string
        """
        pass
class Sequence(Decoder):
"""
Sequence Decoder

View File

@ -1,6 +1,7 @@
use std::sync::{Arc, RwLock};
use crate::utils::PyChar;
use crate::utils::PyPattern;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
@ -14,6 +15,7 @@ use tk::decoders::metaspace::Metaspace;
use tk::decoders::sequence::Sequence;
use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper;
use tk::normalizers::replace::Replace;
use tk::Decoder;
use tokenizers as tk;
@ -46,6 +48,7 @@ impl PyDecoder {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
}
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
DecoderWrapper::Sequence(_) => {
@ -159,6 +162,24 @@ impl PyByteLevelDec {
}
}
/// Replace Decoder
///
/// Applies the same ``pattern -> content`` replacement as the underlying
/// ``tk::normalizers::replace::Replace`` type, which also implements the
/// ``Decoder`` trait.
///
/// NOTE(review): the original docs referenced
/// :class:`~tokenizers.pre_tokenizers.Replace`; no such pre-tokenizer appears
/// in this change — confirm the intended cross-reference (likely
/// :class:`~tokenizers.normalizers.Replace`).
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Replace")]
#[pyo3(text_signature = "(self, pattern, content)")]
pub struct PyReplaceDec {}

#[pymethods]
impl PyReplaceDec {
    #[new]
    // `pattern` goes through the PyPattern extractor; `content` is the
    // replacement string. Errors from `Replace::new` surface as Python errors
    // via ToPyResult.
    fn new(pattern: PyPattern, content: String) -> PyResult<(Self, PyDecoder)> {
        Ok((
            PyReplaceDec {},
            ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
        ))
    }
}
/// WordPiece Decoder
///
/// Args:
@ -473,6 +494,7 @@ impl Decoder for PyDecoderWrapper {
pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyReplaceDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyMetaspaceDec>()?;

View File

@ -3,7 +3,17 @@ import pickle
import pytest
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
from tokenizers.decoders import (
CTC,
BPEDecoder,
ByteLevel,
Decoder,
Metaspace,
Sequence,
WordPiece,
ByteFallback,
Replace,
)
class TestByteLevel:
@ -24,6 +34,18 @@ class TestByteLevel:
assert isinstance(reloaded, ByteLevel)
class TestReplace:
    def test_instantiate(self):
        decoder = Replace("_", " ")
        assert decoder is not None
        assert isinstance(decoder, Decoder)
        assert isinstance(decoder, Replace)
        # assert isinstance(pickle.loads(pickle.dumps(Replace("_", " "))), Replace)

    def test_decoding(self):
        decoder = Replace("_", " ")
        tokens = ["My", "_name", "_is", "_John"]
        assert decoder.decode(tokens) == "My name is John"
class TestWordPiece:
def test_instantiate(self):
assert WordPiece() is not None

View File

@ -15,6 +15,7 @@ use crate::decoders::byte_fallback::ByteFallback;
use crate::decoders::ctc::CTC;
use crate::decoders::sequence::Sequence;
use crate::decoders::wordpiece::WordPiece;
use crate::normalizers::replace::Replace;
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::pre_tokenizers::metaspace::Metaspace;
use crate::{Decoder, Result};
@ -28,6 +29,7 @@ pub enum DecoderWrapper {
Metaspace(Metaspace),
CTC(CTC),
Sequence(Sequence),
Replace(Replace),
// XXX: This is an untagged enum, which unfortunately means order
// is **CRITICAL**. We absolutely need to make sure order is correct.
// Since byte fallback is parameter free, is **has** to be last, and will
@ -44,6 +46,7 @@ impl Decoder for DecoderWrapper {
Self::WordPiece(wp) => wp.decode_chain(tokens),
Self::CTC(ctc) => ctc.decode_chain(tokens),
Self::Sequence(seq) => seq.decode_chain(tokens),
Self::Replace(seq) => seq.decode_chain(tokens),
Self::ByteFallback(bf) => bf.decode_chain(tokens),
}
}
@ -56,6 +59,7 @@ impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
impl_enum_from!(CTC, DecoderWrapper, CTC);
impl_enum_from!(Sequence, DecoderWrapper, Sequence);
impl_enum_from!(Replace, DecoderWrapper, Replace);
#[cfg(test)]
mod tests {

View File

@ -1,3 +1,5 @@
use crate::tokenizer::pattern::Pattern;
use crate::tokenizer::Decoder;
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use crate::utils::SysRegex;
use serde::{Deserialize, Serialize};
@ -83,6 +85,26 @@ impl Normalizer for Replace {
}
}
impl Decoder for Replace {
    /// Decode by applying the same `pattern -> content` substitution to every
    /// token independently: matched spans are replaced by `self.content`,
    /// unmatched spans are copied through verbatim.
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        tokens
            .into_iter()
            .map(|token| -> Result<String> {
                // Pre-size the buffer; the input length is a reasonable estimate.
                let mut new_token = String::with_capacity(token.len());
                for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
                    if is_match {
                        new_token.push_str(&self.content);
                    } else {
                        new_token.push_str(&token[start..stop]);
                    }
                }
                Ok(new_token)
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
use super::*;
@ -124,4 +146,14 @@ mod tests {
assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);
}
#[test]
fn test_replace_decode() {
    let replace = Replace::new("_", " ").unwrap();
    let tokens = vec!["hello".to_string(), "_hello".to_string()];
    let decoded = replace.decode_chain(tokens).unwrap();
    assert_eq!(decoded, vec!["hello", " hello"]);
}
}