Adding Replace to decoder (to undo the Replace Normalizer for (#1195)

Metaspace split).
This commit is contained in:
Nicolas Patry
2023-03-23 23:43:47 +01:00
committed by GitHub
parent 178e294a6a
commit 250d46c676
10 changed files with 135 additions and 1 deletions

View File

@ -3,6 +3,7 @@ from .. import decoders
Decoder = decoders.Decoder
ByteLevel = decoders.ByteLevel
Replace = decoders.Replace
WordPiece = decoders.WordPiece
ByteFallback = decoders.ByteFallback
Metaspace = decoders.Metaspace

View File

@ -150,6 +150,29 @@ class Metaspace(Decoder):
"""
pass
class Replace(Decoder):
"""
Replace Decoder
This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
:class:`~tokenizers.pre_tokenizers.PreTokenizer`.
"""
def __init__(self, pattern, content):
pass
def decode(self, tokens):
"""
Decode the given list of tokens to a final string
Args:
tokens (:obj:`List[str]`):
The list of tokens to decode
Returns:
:obj:`str`: The decoded string
"""
pass
class Sequence(Decoder):
"""
Sequence Decoder

View File

@ -1,6 +1,7 @@
use std::sync::{Arc, RwLock};
use crate::utils::PyChar;
use crate::utils::PyPattern;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
@ -14,6 +15,7 @@ use tk::decoders::metaspace::Metaspace;
use tk::decoders::sequence::Sequence;
use tk::decoders::wordpiece::WordPiece;
use tk::decoders::DecoderWrapper;
use tk::normalizers::replace::Replace;
use tk::Decoder;
use tokenizers as tk;
@ -46,6 +48,7 @@ impl PyDecoder {
Py::new(py, (PyByteFallbackDec {}, base))?.into_py(py)
}
DecoderWrapper::ByteLevel(_) => Py::new(py, (PyByteLevelDec {}, base))?.into_py(py),
DecoderWrapper::Replace(_) => Py::new(py, (PyReplaceDec {}, base))?.into_py(py),
DecoderWrapper::BPE(_) => Py::new(py, (PyBPEDecoder {}, base))?.into_py(py),
DecoderWrapper::CTC(_) => Py::new(py, (PyCTCDecoder {}, base))?.into_py(py),
DecoderWrapper::Sequence(_) => {
@ -159,6 +162,24 @@ impl PyByteLevelDec {
}
}
/// Replace Decoder
///
/// This decoder is to be used in tandem with the :class:`~tokenizers.pre_tokenizers.Replace`
/// :class:`~tokenizers.pre_tokenizers.PreTokenizer`.
#[pyclass(extends=PyDecoder, module = "tokenizers.decoders", name = "Replace")]
#[pyo3(text_signature = "(self, pattern, content)")]
pub struct PyReplaceDec {}
#[pymethods]
impl PyReplaceDec {
#[new]
fn new(pattern: PyPattern, content: String) -> PyResult<(Self, PyDecoder)> {
Ok((
PyReplaceDec {},
ToPyResult(Replace::new(pattern, content)).into_py()?.into(),
))
}
}
/// WordPiece Decoder
///
/// Args:
@ -473,6 +494,7 @@ impl Decoder for PyDecoderWrapper {
pub fn decoders(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDecoder>()?;
m.add_class::<PyByteLevelDec>()?;
m.add_class::<PyReplaceDec>()?;
m.add_class::<PyWordPieceDec>()?;
m.add_class::<PyByteFallbackDec>()?;
m.add_class::<PyMetaspaceDec>()?;

View File

@ -3,7 +3,17 @@ import pickle
import pytest
from tokenizers.decoders import CTC, BPEDecoder, ByteLevel, Decoder, Metaspace, Sequence, WordPiece, ByteFallback
from tokenizers.decoders import (
CTC,
BPEDecoder,
ByteLevel,
Decoder,
Metaspace,
Sequence,
WordPiece,
ByteFallback,
Replace,
)
class TestByteLevel:
@ -24,6 +34,18 @@ class TestByteLevel:
assert isinstance(reloaded, ByteLevel)
class TestReplace:
def test_instantiate(self):
assert Replace("_", " ") is not None
assert isinstance(Replace("_", " "), Decoder)
assert isinstance(Replace("_", " "), Replace)
# assert isinstance(pickle.loads(pickle.dumps(Replace("_", " "))), Replace)
def test_decoding(self):
decoder = Replace("_", " ")
assert decoder.decode(["My", "_name", "_is", "_John"]) == "My name is John"
class TestWordPiece:
def test_instantiate(self):
assert WordPiece() is not None