Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Adding Replace to decoder (to undo the Replace Normalizer for Metaspace split) (#1195)
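The context for this change: Metaspace-style tokenization uses a Replace normalizer to rewrite spaces into a marker character before splitting, and until now there was no matching decoder to undo that rewrite when turning tokens back into text. A minimal round-trip sketch using the Python bindings (assuming a tokenizers build that includes this commit; the "_" marker mirrors the tests below, while real Metaspace setups typically use "▁"):

    from tokenizers.normalizers import Replace as ReplaceNormalizer
    from tokenizers.decoders import Replace as ReplaceDecoder

    # Normalization direction: spaces become the split marker.
    normalizer = ReplaceNormalizer(" ", "_")
    assert normalizer.normalize_str("My name is John") == "My_name_is_John"

    # Decoding direction (new in this commit): the marker becomes a space again.
    decoder = ReplaceDecoder("_", " ")
    assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"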
bindings/node/lib/bindings/decoders.d.ts (vendored, 7 lines changed)
@@ -12,6 +12,13 @@ interface Decoder {
  */
 export function byteLevelDecoder(): Decoder;
 
+/**
+ * Instantiate a new Replace Decoder
+ * @param [pattern] The pattern to replace
+ * @param [content] The replacement.
+ */
+export function replaceDecoder(pattern: string, content: string): Decoder;
+
 /**
  * Instantiate a new WordPiece Decoder
  * @param [prefix='##'] The prefix to use for subwords that are not a beginning-of-word
bindings/node/lib/bindings/decoders.js (vendored)

@@ -2,6 +2,7 @@ const native = require("./native");
 
 module.exports = {
   byteLevelDecoder: native.decoders_ByteLevel,
+  replaceDecoder: native.decoders_Replace,
   wordPieceDecoder: native.decoders_WordPiece,
   byteFallbackDecoder: native.decoders_ByteFallback,
   metaspaceDecoder: native.decoders_Metaspace,
bindings/node/lib/bindings/decoders.test.ts

@@ -3,6 +3,7 @@ import {
   byteFallbackDecoder,
   ctcDecoder,
   metaspaceDecoder,
+  replaceDecoder,
   sequenceDecoder,
   wordPieceDecoder,
 } from "./decoders";
@@ -44,6 +45,12 @@ describe("byteFallbackDecoder", () => {
   });
 });
 
+describe("replaceDecoder", () => {
+  it("can decode arrays of strings", () => {
+    expect(replaceDecoder("_", " ").decode(["Hello", "_Hello"])).toEqual("Hello Hello");
+  });
+});
+
 describe("metaspaceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
     expect(metaspaceDecoder(undefined)).toBeDefined();
bindings/node/native/src/decoders.rs

@@ -57,6 +57,20 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// replace()
+fn replace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let pattern: String = cx.extract::<String>(0)?;
+    let content: String = cx.extract::<String>(1)?;
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::normalizers::replace::Replace::new(pattern, content)
+            .map_err(|e| Error(e.to_string()))?
+            .into(),
+    ));
+    Ok(decoder)
+}
+
 /// wordpiece(prefix: String = "##", cleanup: bool)
 fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     let prefix = cx
@@ -156,6 +170,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
 /// Register everything here
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
+    m.export_function(&format!("{}_Replace", prefix), replace)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
bindings/python/py_src/tokenizers/decoders/__init__.py

@@ -3,6 +3,7 @@ from .. import decoders
 
 Decoder = decoders.Decoder
 ByteLevel = decoders.ByteLevel
+Replace = decoders.Replace
 WordPiece = decoders.WordPiece
 ByteFallback = decoders.ByteFallback
 Metaspace = decoders.Metaspace
bindings/python/py_src/tokenizers/decoders/__init__.pyi

@@ -150,6 +150,29 @@ class Metaspace(Decoder):
     """
     pass
 
+class Replace(Decoder):
+    """
+    Replace Decoder
+
+    This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
+    :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
+    """
+
+    def __init__(self, pattern, content):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Sequence(Decoder):
     """
     Sequence Decoder
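Taken together, the stub above documents the whole user-facing surface: `Replace(pattern, content)` plus `decode(tokens)`. A small usage sketch, mirroring the values the test suites in this commit use:

    from tokenizers.decoders import Replace

    decoder = Replace("_", " ")
    # decode() runs decode_chain over the tokens and concatenates the results.
    assert decoder.decode(["Hello", "_Hello"]) == "Hello Hello"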
bindings/python/src/decoders.rs

@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use crate::utils::PyChar;
+use crate::utils::PyPattern;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -14,6 +15,7 @@ use tk::decoders::metaspace::Metaspace;
 use tk::decoders::sequence::Sequence;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
+use tk::normalizers::replace::Replace;
 use tk::Decoder;
 use tokenizers as tk;
 
@@ -46,6 +48,7 @@ impl PyDecoder {
                 Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
             }
             DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
+            DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
             DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
             DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
             DecoderWrapper::Sequence(_) => {
@@ -159,6 +162,24 @@ impl PyByteLevelDec {
     }
 }
 
+/// Replace Decoder
+///
+/// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
+/// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Replace")]
+#[pyo3(text_signature = "(self, pattern, content)")]
+pub struct PyReplaceDec {}
+#[pymethods]
+impl PyReplaceDec {
+    #[new]
+    fn new(pattern: PyPattern, content: String) -> PyResult<(Self, PyDecoder)> {
+        Ok((
+            PyReplaceDec {},
+            ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
+        ))
+    }
+}
+
 /// WordPiece Decoder
 ///
 /// Args:
@@ -473,6 +494,7 @@ impl Decoder for PyDecoderWrapper {
 pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyDecoder>()?;
     m.add_class::<PyByteLevelDec>()?;
+    m.add_class::<PyReplaceDec>()?;
     m.add_class::<PyWordPieceDec>()?;
     m.add_class::<PyByteFallbackDec>()?;
     m.add_class::<PyMetaspaceDec>()?;
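One binding detail worth noting: `PyReplaceDec::new` takes a `PyPattern`, the same conversion type the Replace normalizer binding uses, so the pattern argument should accept a `tokenizers.Regex` as well as a plain string. The Regex form below is my reading of the binding, not something this commit's tests exercise, so treat it as an assumption:

    from tokenizers import Regex
    from tokenizers.decoders import Replace

    # String pattern: covered by the tests in this commit.
    assert Replace("_", " ").decode(["My", "_name"]) == "My name"

    # Regex pattern: assumed to work because PyPattern also wraps Regex;
    # this would collapse a run of underscores into a single space.
    decoder = Replace(Regex("_+"), " ")
    print(decoder.decode(["My", "__name"]))  # expected: "My name"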
bindings/python/tests/bindings/test_decoders.py

@@ -3,7 +3,17 @@ import pickle
 
 import pytest
 
-from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
+from tokenizers.decoders import (
+    CTC,
+    BPEDecoder,
+    ByteLevel,
+    Decoder,
+    Metaspace,
+    Sequence,
+    WordPiece,
+    ByteFallback,
+    Replace,
+)
 
 
 class TestByteLevel:
@@ -24,6 +34,18 @@ class TestByteLevel:
         assert isinstance(reloaded, ByteLevel)
 
 
+class TestReplace:
+    def test_instantiate(self):
+        assert Replace("_", " ") is not None
+        assert isinstance(Replace("_", " "), Decoder)
+        assert isinstance(Replace("_", " "), Replace)
+        # assert isinstance(pickle.loads(pickle.dumps(Replace("_", " "))), Replace)
+
+    def test_decoding(self):
+        decoder = Replace("_", " ")
+        assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"
+
+
 class TestWordPiece:
     def test_instantiate(self):
         assert WordPiece() is not None
tokenizers/src/decoders/mod.rs

@@ -15,6 +15,7 @@ use crate::decoders::byte_fallback::ByteFallback;
 use crate::decoders::ctc::CTC;
 use crate::decoders::sequence::Sequence;
 use crate::decoders::wordpiece::WordPiece;
+use crate::normalizers::replace::Replace;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
 use crate::{Decoder, Result};
@@ -28,6 +29,7 @@ pub enum DecoderWrapper {
     Metaspace(Metaspace),
     CTC(CTC),
     Sequence(Sequence),
+    Replace(Replace),
     // XXX: This is an untagged enum, which unfortunately means order
     // is **CRITICAL**. We absolutely need to make sure order is correct.
     // Since byte fallback is parameter free, it **has** to be last, and will
@@ -44,6 +46,7 @@ impl Decoder for DecoderWrapper {
             Self::WordPiece(wp) => wp.decode_chain(tokens),
             Self::CTC(ctc) => ctc.decode_chain(tokens),
             Self::Sequence(seq) => seq.decode_chain(tokens),
+            Self::Replace(seq) => seq.decode_chain(tokens),
             Self::ByteFallback(bf) => bf.decode_chain(tokens),
         }
     }
@@ -56,6 +59,7 @@ impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
 impl_enum_from!(Sequence, DecoderWrapper, Sequence);
+impl_enum_from!(Replace, DecoderWrapper, Replace);
 
 #[cfg(test)]
 mod tests {
tokenizers/src/normalizers/replace.rs

@@ -1,3 +1,5 @@
+use crate::tokenizer::pattern::Pattern;
+use crate::tokenizer::Decoder;
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
 use crate::utils::SysRegex;
 use serde::{Deserialize, Serialize};
@@ -83,6 +85,26 @@ impl Normalizer for Replace {
     }
 }
 
+impl Decoder for Replace {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        tokens
+            .into_iter()
+            .map(|token| -> Result<String> {
+                let mut new_token = "".to_string();
+
+                for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
+                    if is_match {
+                        new_token.push_str(&self.content);
+                    } else {
+                        new_token.push_str(&token[start..stop]);
+                    }
+                }
+                Ok(new_token)
+            })
+            .collect()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -124,4 +146,14 @@ mod tests {
         assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
         assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);
     }
+
+    #[test]
+    fn test_replace_decode() {
+        let original = vec!["hello".to_string(), "_hello".to_string()];
+        let replace = Replace::new("_", " ").unwrap();
+        assert_eq!(
+            replace.decode_chain(original).unwrap(),
+            vec!["hello", " hello"]
+        );
+    }
 }
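The decode_chain implementation above is the heart of the change: for each token it walks the pattern's match spans, substitutes self.content for matching spans, and copies non-matching spans through untouched. A rough Python equivalent of that loop, using re for illustration rather than the crate's SysRegex:

    import re

    def decode_chain(tokens, pattern, content):
        """Mirror of the Rust loop: rebuild each token span by span."""
        out = []
        for token in tokens:
            pieces, last = [], 0
            for m in re.finditer(pattern, token):
                pieces.append(token[last:m.start()])  # non-match span, kept as-is
                pieces.append(content)                # match span, replaced
                last = m.end()
            pieces.append(token[last:])               # trailing non-match span
            out.append("".join(pieces))
        return out

    # Matches the Rust unit test added above.
    assert decode_chain(["hello", "_hello"], "_", " ") == ["hello", " hello"]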