Add bytelevel normalizer to fix decode when adding tokens to BPE (#1555)

* feature dependent test

* nit about 嗎

* update

* actuallyfix it

* update the test

add it

fix

* stub

* Update tokenizers/src/pre_tokenizers/byte_level.rs

Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>

* skip failing test

* add normalizer to init

---------

Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
This commit is contained in:
Arthur
2024-07-15 12:12:03 +02:00
committed by GitHub
parent f2a44dc5d1
commit 4ea2f235b0
9 changed files with 335 additions and 6 deletions

View File

@ -9,8 +9,8 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tk::normalizers::{
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace, Strip,
StripAccents, NFC, NFD, NFKC, NFKD,
BertNormalizer, ByteLevel, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace,
Strip, StripAccents, NFC, NFD, NFKC, NFKD,
};
use tk::{NormalizedString, Normalizer};
use tokenizers as tk;
@ -70,6 +70,9 @@ impl PyNormalizer {
Py::new(py, (PyBertNormalizer {}, base))?.into_py(py)
}
NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))?.into_py(py),
NormalizerWrapper::ByteLevel(_) => {
Py::new(py, (PyByteLevel {}, base))?.into_py(py)
}
NormalizerWrapper::StripAccents(_) => {
Py::new(py, (PyStripAccents {}, base))?.into_py(py)
}
@ -435,6 +438,18 @@ impl PyPrepend {
}
}
/// Bytelevel Normalizer
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "ByteLevel")]
pub struct PyByteLevel {}
#[pymethods]
impl PyByteLevel {
#[new]
#[pyo3(text_signature = "(self)")]
fn new() -> (Self, PyNormalizer) {
(PyByteLevel {}, ByteLevel::new().into())
}
}
/// StripAccents normalizer
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "StripAccents")]
pub struct PyStripAccents {}
@ -647,6 +662,7 @@ pub fn normalizers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyStrip>()?;
m.add_class::<PyStripAccents>()?;
m.add_class::<PyPrepend>()?;
m.add_class::<PyByteLevel>()?;
m.add_class::<PyNmt>()?;
m.add_class::<PyPrecompiled>()?;
m.add_class::<PyReplace>()?;