mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Adding Replace
to decoder (to undo the Replace Normalizer for (#1195)
Metaspace split).
This commit is contained in:
@ -3,6 +3,7 @@ from .. import decoders
|
||||
|
||||
Decoder = decoders.Decoder
|
||||
ByteLevel = decoders.ByteLevel
|
||||
Replace = decoders.Replace
|
||||
WordPiece = decoders.WordPiece
|
||||
ByteFallback = decoders.ByteFallback
|
||||
Metaspace = decoders.Metaspace
|
||||
|
@ -150,6 +150,29 @@ class Metaspace(Decoder):
|
||||
"""
|
||||
pass
|
||||
|
||||
class Replace(Decoder):
|
||||
"""
|
||||
Replace Decoder
|
||||
|
||||
This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
|
||||
:class:`~tokenizers.pre_tokenizers.PreTokenizer`.
|
||||
"""
|
||||
|
||||
def __init__(self, pattern, content):
|
||||
pass
|
||||
def decode(self, tokens):
|
||||
"""
|
||||
Decode the given list of tokens to a final string
|
||||
|
||||
Args:
|
||||
tokens (:obj:`List[str]`):
|
||||
The list of tokens to decode
|
||||
|
||||
Returns:
|
||||
:obj:`str`: The decoded string
|
||||
"""
|
||||
pass
|
||||
|
||||
class Sequence(Decoder):
|
||||
"""
|
||||
Sequence Decoder
|
||||
|
@ -1,6 +1,7 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use crate::utils::PyChar;
|
||||
use crate::utils::PyPattern;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
@ -14,6 +15,7 @@ use tk::decoders::metaspace::Metaspace;
|
||||
use tk::decoders::sequence::Sequence;
|
||||
use tk::decoders::wordpiece::WordPiece;
|
||||
use tk::decoders::DecoderWrapper;
|
||||
use tk::normalizers::replace::Replace;
|
||||
use tk::Decoder;
|
||||
use tokenizers as tk;
|
||||
|
||||
@ -46,6 +48,7 @@ impl PyDecoder {
|
||||
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
|
||||
}
|
||||
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
|
||||
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
|
||||
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
|
||||
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
|
||||
DecoderWrapper::Sequence(_) => {
|
||||
@ -159,6 +162,24 @@ impl PyByteLevelDec {
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace Decoder
|
||||
///
|
||||
/// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
|
||||
/// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
|
||||
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Replace")]
|
||||
#[pyo3(text_signature = "(self, pattern, content)")]
|
||||
pub struct PyReplaceDec {}
|
||||
#[pymethods]
|
||||
impl PyReplaceDec {
|
||||
#[new]
|
||||
fn new(pattern: PyPattern, content: String) -> PyResult<(Self, PyDecoder)> {
|
||||
Ok((
|
||||
PyReplaceDec {},
|
||||
ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// WordPiece Decoder
|
||||
///
|
||||
/// Args:
|
||||
@ -473,6 +494,7 @@ impl Decoder for PyDecoderWrapper {
|
||||
pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyDecoder>()?;
|
||||
m.add_class::<PyByteLevelDec>()?;
|
||||
m.add_class::<PyReplaceDec>()?;
|
||||
m.add_class::<PyWordPieceDec>()?;
|
||||
m.add_class::<PyByteFallbackDec>()?;
|
||||
m.add_class::<PyMetaspaceDec>()?;
|
||||
|
@ -3,7 +3,17 @@ import pickle
|
||||
|
||||
import pytest
|
||||
|
||||
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
|
||||
from tokenizers.decoders import (
|
||||
CTC,
|
||||
BPEDecoder,
|
||||
ByteLevel,
|
||||
Decoder,
|
||||
Metaspace,
|
||||
Sequence,
|
||||
WordPiece,
|
||||
ByteFallback,
|
||||
Replace,
|
||||
)
|
||||
|
||||
|
||||
class TestByteLevel:
|
||||
@ -24,6 +34,18 @@ class TestByteLevel:
|
||||
assert isinstance(reloaded, ByteLevel)
|
||||
|
||||
|
||||
class TestReplace:
|
||||
def test_instantiate(self):
|
||||
assert Replace("_", " ") is not None
|
||||
assert isinstance(Replace("_", " "), Decoder)
|
||||
assert isinstance(Replace("_", " "), Replace)
|
||||
# assert isinstance(pickle.loads(pickle.dumps(Replace("_", " "))), Replace)
|
||||
|
||||
def test_decoding(self):
|
||||
decoder = Replace("_", " ")
|
||||
assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"
|
||||
|
||||
|
||||
class TestWordPiece:
|
||||
def test_instantiate(self):
|
||||
assert WordPiece() is not None
|
||||
|
Reference in New Issue
Block a user