Adding Replace to decoder (to undo the Replace Normalizer for Metaspace split) (#1195)
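In short: a Metaspace-style split can be implemented with a Replace normalizer that turns spaces into a split marker, and this commit adds the inverse operation as a decoder, in the Rust core and in both the node and Python bindings. A minimal Python sketch of the new API (the "_" marker and the expected output are taken verbatim from the commit's own tests below):

    from tokenizers.decoders import Replace

    # Undo a Replace normalizer that mapped " " -> "_" during normalization.
    decoder = Replace("_", " ")
    print(decoder.decode(["My", "_name", "_is", "_John"]))  # -> "My name is John"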
bindings/node/lib/bindings/decoders.d.ts
@@ -12,6 +12,13 @@ interface Decoder {
  */
 export function byteLevelDecoder(): Decoder;
 
+/**
+ * Instantiate a new Replace Decoder
+ * @param [pattern] The pattern to replace
+ * @param [content] The replacement.
+ */
+export function replaceDecoder(pattern: string, content: string): Decoder;
+
 /**
  * Instantiate a new WordPiece Decoder
  * @param [prefix='##'] The prefix to use for subwords that are not a beginning-of-word
bindings/node/lib/bindings/decoders.js
@@ -2,6 +2,7 @@ const native = require("./native");
 
 module.exports = {
   byteLevelDecoder: native.decoders_ByteLevel,
+  replaceDecoder: native.decoders_Replace,
   wordPieceDecoder: native.decoders_WordPiece,
   byteFallbackDecoder: native.decoders_ByteFallback,
   metaspaceDecoder: native.decoders_Metaspace,
bindings/node/lib/bindings/decoders.test.ts
@@ -3,6 +3,7 @@ import {
   byteFallbackDecoder,
   ctcDecoder,
   metaspaceDecoder,
+  replaceDecoder,
   sequenceDecoder,
   wordPieceDecoder,
 } from "./decoders";
@@ -44,6 +45,12 @@ describe("byteFallbackDecoder", () => {
   });
 });
 
+describe("replaceDecoder", () => {
+  it("can decode arrays of strings", () => {
+    expect(replaceDecoder("_", " ").decode(["Hello", "_Hello"])).toEqual("Hello Hello");
+  });
+});
+
 describe("metaspaceDecoder", () => {
   it("accepts `undefined` as first parameter", () => {
     expect(metaspaceDecoder(undefined)).toBeDefined();
bindings/node/native/src/decoders.rs
@@ -57,6 +57,20 @@ fn byte_level(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     Ok(decoder)
 }
 
+/// replace()
+fn replace(mut cx: FunctionContext) -> JsResult<JsDecoder> {
+    let pattern: String = cx.extract::<String>(0)?;
+    let content: String = cx.extract::<String>(1)?;
+    let mut decoder = JsDecoder::new::<_, JsDecoder, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    decoder.borrow_mut(&guard).decoder = Some(Arc::new(
+        tk::normalizers::replace::Replace::new(pattern, content)
+            .map_err(|e| Error(e.to_string()))?
+            .into(),
+    ));
+    Ok(decoder)
+}
+
 /// wordpiece(prefix: String = "##", cleanup: bool)
 fn wordpiece(mut cx: FunctionContext) -> JsResult<JsDecoder> {
     let prefix = cx
@@ -156,6 +170,7 @@ fn sequence(mut cx: FunctionContext) -> JsResult<JsDecoder> {
 /// Register everything here
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_ByteLevel", prefix), byte_level)?;
+    m.export_function(&format!("{}_Replace", prefix), replace)?;
     m.export_function(&format!("{}_WordPiece", prefix), wordpiece)?;
     m.export_function(&format!("{}_ByteFallback", prefix), byte_fallback)?;
     m.export_function(&format!("{}_Metaspace", prefix), metaspace)?;
bindings/python/py_src/tokenizers/decoders/__init__.py
@@ -3,6 +3,7 @@ from .. import decoders
 
 Decoder = decoders.Decoder
 ByteLevel = decoders.ByteLevel
+Replace = decoders.Replace
 WordPiece = decoders.WordPiece
 ByteFallback = decoders.ByteFallback
 Metaspace = decoders.Metaspace
bindings/python/py_src/tokenizers/decoders/__init__.pyi
@@ -150,6 +150,29 @@ class Metaspace(Decoder):
     """
     pass
 
+class Replace(Decoder):
+    """
+    Replace Decoder
+
+    This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
+    :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
+    """
+
+    def __init__(self, pattern, content):
+        pass
+    def decode(self, tokens):
+        """
+        Decode the given list of tokens to a final string
+
+        Args:
+            tokens (:obj:`List[str]`):
+                The list of tokens to decode
+
+        Returns:
+            :obj:`str`: The decoded string
+        """
+        pass
+
 class Sequence(Decoder):
     """
     Sequence Decoder
bindings/python/src/decoders.rs
@@ -1,6 +1,7 @@
 use std::sync::{Arc, RwLock};
 
 use crate::utils::PyChar;
+use crate::utils::PyPattern;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
@@ -14,6 +15,7 @@ use tk::decoders::metaspace::Metaspace;
 use tk::decoders::sequence::Sequence;
 use tk::decoders::wordpiece::WordPiece;
 use tk::decoders::DecoderWrapper;
+use tk::normalizers::replace::Replace;
 use tk::Decoder;
 use tokenizers as tk;
 
@@ -46,6 +48,7 @@ impl PyDecoder {
                 Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
             }
             DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
+            DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
             DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
             DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
             DecoderWrapper::Sequence(_) => {
@@ -159,6 +162,24 @@ impl PyByteLevelDec {
     }
 }
 
+/// Replace Decoder
+///
+/// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
+/// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
+#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Replace")]
+#[pyo3(text_signature = "(self, pattern, content)")]
+pub struct PyReplaceDec {}
+#[pymethods]
+impl PyReplaceDec {
+    #[new]
+    fn new(pattern: PyPattern, content: String) -> PyResult<(Self, PyDecoder)> {
+        Ok((
+            PyReplaceDec {},
+            ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
+        ))
+    }
+}
+
 /// WordPiece Decoder
 ///
 /// Args:
@@ -473,6 +494,7 @@ impl Decoder for PyDecoderWrapper {
 pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyDecoder>()?;
     m.add_class::<PyByteLevelDec>()?;
+    m.add_class::<PyReplaceDec>()?;
     m.add_class::<PyWordPieceDec>()?;
     m.add_class::<PyByteFallbackDec>()?;
     m.add_class::<PyMetaspaceDec>()?;
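A note on the `PyPattern` argument above: elsewhere in these bindings a pattern may be given either as a literal string or as a `tokenizers.Regex`, and `PyReplaceDec::new` reuses that machinery. A hedged sketch, assuming the existing pattern support carries over to the decoder unchanged:

    from tokenizers import Regex
    from tokenizers.decoders import Replace

    Replace("_", " ")            # literal pattern
    Replace(Regex("_+"), " ")    # regex pattern via tokenizers.Regex (assumed to work here)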
bindings/python/tests/bindings/test_decoders.py
@@ -3,7 +3,17 @@ import pickle
 
 import pytest
 
-from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
+from tokenizers.decoders import (
+    CTC,
+    BPEDecoder,
+    ByteLevel,
+    Decoder,
+    Metaspace,
+    Sequence,
+    WordPiece,
+    ByteFallback,
+    Replace,
+)
 
 
 class TestByteLevel:
@@ -24,6 +34,18 @@ class TestByteLevel:
         assert isinstance(reloaded, ByteLevel)
 
 
+class TestReplace:
+    def test_instantiate(self):
+        assert Replace("_", " ") is not None
+        assert isinstance(Replace("_", " "), Decoder)
+        assert isinstance(Replace("_", " "), Replace)
+        # assert isinstance(pickle.loads(pickle.dumps(Replace("_", " "))), Replace)
+
+    def test_decoding(self):
+        decoder = Replace("_", " ")
+        assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"
+
+
 class TestWordPiece:
     def test_instantiate(self):
         assert WordPiece() is not None
tokenizers/src/decoders/mod.rs
@@ -15,6 +15,7 @@ use crate::decoders::byte_fallback::ByteFallback;
 use crate::decoders::ctc::CTC;
 use crate::decoders::sequence::Sequence;
 use crate::decoders::wordpiece::WordPiece;
+use crate::normalizers::replace::Replace;
 use crate::pre_tokenizers::byte_level::ByteLevel;
 use crate::pre_tokenizers::metaspace::Metaspace;
 use crate::{Decoder, Result};
@@ -28,6 +29,7 @@ pub enum DecoderWrapper {
     Metaspace(Metaspace),
     CTC(CTC),
     Sequence(Sequence),
+    Replace(Replace),
     // XXX: This is an untagged enum, which unfortunately means order
     // is **CRITICAL**. We absolutely need to make sure order is correct.
     // Since byte fallback is parameter free, is **has** to be last, and will
@@ -44,6 +46,7 @@ impl Decoder for DecoderWrapper {
             Self::WordPiece(wp) => wp.decode_chain(tokens),
             Self::CTC(ctc) => ctc.decode_chain(tokens),
             Self::Sequence(seq) => seq.decode_chain(tokens),
+            Self::Replace(seq) => seq.decode_chain(tokens),
             Self::ByteFallback(bf) => bf.decode_chain(tokens),
         }
     }
@@ -56,6 +59,7 @@ impl_enum_from!(Metaspace, DecoderWrapper, Metaspace);
 impl_enum_from!(WordPiece, DecoderWrapper, WordPiece);
 impl_enum_from!(CTC, DecoderWrapper, CTC);
 impl_enum_from!(Sequence, DecoderWrapper, Sequence);
+impl_enum_from!(Replace, DecoderWrapper, Replace);
 
 #[cfg(test)]
 mod tests {
tokenizers/src/normalizers/replace.rs
@@ -1,3 +1,5 @@
+use crate::tokenizer::pattern::Pattern;
+use crate::tokenizer::Decoder;
 use crate::tokenizer::{NormalizedString, Normalizer, Result};
 use crate::utils::SysRegex;
 use serde::{Deserialize, Serialize};
@@ -83,6 +85,26 @@ impl Normalizer for Replace {
     }
 }
 
+impl Decoder for Replace {
+    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
+        tokens
+            .into_iter()
+            .map(|token| -> Result<String> {
+                let mut new_token = "".to_string();
+
+                for ((start, stop), is_match) in (&self.regex).find_matches(&token)? {
+                    if is_match {
+                        new_token.push_str(&self.content);
+                    } else {
+                        new_token.push_str(&token[start..stop]);
+                    }
+                }
+                Ok(new_token)
+            })
+            .collect()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -124,4 +146,14 @@ mod tests {
         assert_eq!(serde_json::to_string(&replace).unwrap(), replace_s);
         assert_eq!(serde_json::from_str::<Replace>(replace_s).unwrap(), replace);
     }
+
+    #[test]
+    fn test_replace_decode() {
+        let original = vec!["hello".to_string(), "_hello".to_string()];
+        let replace = Replace::new("_", " ").unwrap();
+        assert_eq!(
+            replace.decode_chain(original).unwrap(),
+            vec!["hello", " hello"]
+        );
+    }
 }
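For intuition, `decode_chain` rewrites each token independently: every match of the pattern is replaced by `content`, unmatched spans are copied through verbatim, and the number of tokens never changes. A rough Python equivalent of the Rust loop above, assuming a plain-string pattern:

    import re

    def decode_chain(tokens, pattern, content):
        # Mirror find_matches(): swap each match for `content`,
        # keep the text between matches verbatim.
        return [re.sub(re.escape(pattern), content, token) for token in tokens]

    assert decode_chain(["hello", "_hello"], "_", " ") == ["hello", " hello"]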