diff --git a/bindings/node/lib/bindings/normalizers.d.ts b/bindings/node/lib/bindings/normalizers.d.ts index 57668683..b26a7704 100644 --- a/bindings/node/lib/bindings/normalizers.d.ts +++ b/bindings/node/lib/bindings/normalizers.d.ts @@ -78,6 +78,12 @@ export function lowercaseNormalizer(): Normalizer; */ export function stripNormalizer(left?: boolean, right?: boolean): Normalizer; +/** + * Returns a new Prepend Normalizer + * @param prepend The string to prepend + */ +export function prependNormalizer(prepend: string): Normalizer; + /** * Returns a new StripAccents Normalizer */ diff --git a/bindings/node/lib/bindings/normalizers.js b/bindings/node/lib/bindings/normalizers.js index 46fcdfbb..34a6f47f 100644 --- a/bindings/node/lib/bindings/normalizers.js +++ b/bindings/node/lib/bindings/normalizers.js @@ -9,6 +9,7 @@ module.exports = { sequenceNormalizer: native.normalizers_Sequence, lowercaseNormalizer: native.normalizers_Lowercase, stripNormalizer: native.normalizers_Strip, + prependNormalizer: native.normalizers_Prepend, stripAccentsNormalizer: native.normalizers_StripAccents, nmtNormalizer: native.normalizers_Nmt, precompiledNormalizer: native.normalizers_Precompiled, diff --git a/bindings/node/lib/bindings/normalizers.test.ts b/bindings/node/lib/bindings/normalizers.test.ts index a774ba54..5cf4801b 100644 --- a/bindings/node/lib/bindings/normalizers.test.ts +++ b/bindings/node/lib/bindings/normalizers.test.ts @@ -1,4 +1,8 @@ -import { stripAccentsNormalizer, stripNormalizer } from "./normalizers"; +import { + prependNormalizer, + stripAccentsNormalizer, + stripNormalizer, +} from "./normalizers"; describe("stripNormalizer", () => { it("instantiates with no parameters", () => { @@ -24,6 +28,12 @@ describe("stripNormalizer", () => { expect(normalizer.constructor.name).toEqual("Normalizer"); }); + it("prepend instantiates with one parameter", () => { + const normalizer = prependNormalizer("_"); + expect(normalizer.constructor.name).toEqual("Normalizer"); + 
expect(normalizer.normalizeString("Hello")).toEqual("_Hello"); + }); + it("can normalize strings", () => { const normalizer = stripNormalizer(); expect(normalizer.normalizeString(" Hello there ")).toEqual("Hello there"); diff --git a/bindings/node/native/src/normalizers.rs b/bindings/node/native/src/normalizers.rs index 0e5fa0c0..ca84e4c4 100644 --- a/bindings/node/native/src/normalizers.rs +++ b/bindings/node/native/src/normalizers.rs @@ -175,6 +175,18 @@ fn strip(mut cx: FunctionContext) -> JsResult { Ok(normalizer) } + +/// prepend(prepend: string) +fn prepend(mut cx: FunctionContext) -> JsResult { + let prepend: String = cx.extract::(0)?; + + let mut normalizer = JsNormalizer::new::<_, JsNormalizer, _>(&mut cx, vec![])?; + let guard = cx.lock(); + normalizer.borrow_mut(&guard).normalizer = + Some(tk::normalizers::prepend::Prepend::new(prepend).into()); + + Ok(normalizer) +} /// strip_accents() fn strip_accents(mut cx: FunctionContext) -> JsResult { let mut normalizer = JsNormalizer::new::<_, JsNormalizer, _>(&mut cx, vec![])?; @@ -267,6 +279,7 @@ pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> { m.export_function(&format!("{}_Sequence", prefix), sequence)?; m.export_function(&format!("{}_Lowercase", prefix), lowercase)?; m.export_function(&format!("{}_Strip", prefix), strip)?; + m.export_function(&format!("{}_Prepend", prefix), prepend)?; m.export_function(&format!("{}_StripAccents", prefix), strip_accents)?; m.export_function(&format!("{}_Nmt", prefix), nmt)?; m.export_function(&format!("{}_Precompiled", prefix), precompiled)?; diff --git a/bindings/python/py_src/tokenizers/normalizers/__init__.py b/bindings/python/py_src/tokenizers/normalizers/__init__.py index 6a34c1a8..15a16f1e 100644 --- a/bindings/python/py_src/tokenizers/normalizers/__init__.py +++ b/bindings/python/py_src/tokenizers/normalizers/__init__.py @@ -9,6 +9,7 @@ NFC = normalizers.NFC NFKC = normalizers.NFKC Sequence = normalizers.Sequence Lowercase = 
normalizers.Lowercase +Prepend = normalizers.Prepend Strip = normalizers.Strip StripAccents = normalizers.StripAccents Nmt = normalizers.Nmt diff --git a/bindings/python/py_src/tokenizers/normalizers/__init__.pyi b/bindings/python/py_src/tokenizers/normalizers/__init__.pyi index dbb3e95e..09c2d839 100644 --- a/bindings/python/py_src/tokenizers/normalizers/__init__.pyi +++ b/bindings/python/py_src/tokenizers/normalizers/__init__.pyi @@ -379,6 +379,46 @@ class Precompiled(Normalizer): """ pass +class Prepend(Normalizer): + """ + Prepend normalizer + """ + + def __init__(self, prepend): + pass + def normalize(self, normalized): + """ + Normalize a :class:`~tokenizers.NormalizedString` in-place + + This method allows to modify a :class:`~tokenizers.NormalizedString` to + keep track of the alignment information. If you just want to see the result + of the normalization on a raw string, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize_str` + + Args: + normalized (:class:`~tokenizers.NormalizedString`): + The normalized string on which to apply this + :class:`~tokenizers.normalizers.Normalizer` + """ + pass + def normalize_str(self, sequence): + """ + Normalize the given string + + This method provides a way to visualize the effect of a + :class:`~tokenizers.normalizers.Normalizer` but it does not keep track of the alignment + information. 
If you need to get/convert offsets, you can use + :meth:`~tokenizers.normalizers.Normalizer.normalize` + + Args: + sequence (:obj:`str`): + A string to normalize + + Returns: + :obj:`str`: A string after normalization + """ + pass + class Replace(Normalizer): """ Replace normalizer diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs index d825482d..442a01c1 100644 --- a/bindings/python/src/normalizers.rs +++ b/bindings/python/src/normalizers.rs @@ -9,8 +9,8 @@ use crate::utils::{PyNormalizedString, PyNormalizedStringRefMut, PyPattern}; use serde::ser::SerializeStruct; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use tk::normalizers::{ - BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Replace, Strip, StripAccents, - NFC, NFD, NFKC, NFKD, + BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Prepend, Replace, Strip, + StripAccents, NFC, NFD, NFKC, NFKD, }; use tk::{NormalizedString, Normalizer}; use tokenizers as tk; @@ -69,6 +69,7 @@ impl PyNormalizer { NormalizerWrapper::StripNormalizer(_) => { Py::new(py, (PyBertNormalizer {}, base))?.into_py(py) } + NormalizerWrapper::Prepend(_) => Py::new(py, (PyPrepend {}, base))?.into_py(py), NormalizerWrapper::StripAccents(_) => { Py::new(py, (PyStripAccents {}, base))?.into_py(py) } @@ -172,7 +173,8 @@ macro_rules! 
getter { let super_ = $self.as_ref(); if let PyNormalizerTypeWrapper::Single(ref norm) = super_.normalizer { let wrapper = norm.read().unwrap(); - if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = *wrapper { + if let PyNormalizerWrapper::Wrapped(NormalizerWrapper::$variant(o)) = (*wrapper).clone() + { o.$name } else { unreachable!() @@ -413,6 +415,29 @@ impl PyStrip { } } +/// Prepend normalizer +#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "Prepend")] +#[pyo3(text_signature = "(self, prepend)")] +pub struct PyPrepend {} +#[pymethods] +impl PyPrepend { + #[getter] + fn get_prepend(self_: PyRef) -> String { + getter!(self_, Prepend, prepend) + } + + #[setter] + fn set_prepend(self_: PyRef, prepend: String) { + setter!(self_, Prepend, prepend, prepend) + } + + #[new] + #[pyo3(signature = (prepend="▁".to_string()))] + fn new(prepend: String) -> (Self, PyNormalizer) { + (PyPrepend {}, Prepend::new(prepend).into()) + } +} + /// StripAccents normalizer #[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name = "StripAccents")] #[pyo3(text_signature = "(self)")] @@ -624,6 +649,7 @@ pub fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/bindings/python/tests/bindings/test_normalizers.py b/bindings/python/tests/bindings/test_normalizers.py index f6f69015..cf9f3d1a 100644 --- a/bindings/python/tests/bindings/test_normalizers.py +++ b/bindings/python/tests/bindings/test_normalizers.py @@ -4,7 +4,7 @@ import pytest from tokenizers import NormalizedString, Tokenizer from tokenizers.models import BPE -from tokenizers.normalizers import BertNormalizer, Lowercase, Normalizer, Sequence, Strip +from tokenizers.normalizers import BertNormalizer, Lowercase, Normalizer, Sequence, Strip, Prepend class TestBertNormalizer: @@ -119,6 +119,28 @@ class TestStrip: assert 
normalizer.right == False +class TestPrepend: + def test_instantiate(self): + assert isinstance(Prepend("▁"), Normalizer) + assert isinstance(Prepend("▁"), Prepend) + assert isinstance(pickle.loads(pickle.dumps(Prepend("▁"))), Prepend) + + def test_prepend(self): + normalizer = Prepend(prepend="▁") + + output = normalizer.normalize_str("hello") + assert output == "▁hello" + + def test_can_modify(self): + normalizer = Prepend("▁") + + assert normalizer.prepend == "▁" + + # Modify these + normalizer.prepend = "-" + assert normalizer.prepend == "-" + + class TestCustomNormalizer: class BadCustomNormalizer: def normalize(self, normalized, wrong): diff --git a/tokenizers/src/normalizers/mod.rs b/tokenizers/src/normalizers/mod.rs index b563fa5c..8ac4c58e 100644 --- a/tokenizers/src/normalizers/mod.rs +++ b/tokenizers/src/normalizers/mod.rs @@ -1,5 +1,6 @@ pub mod bert; pub mod precompiled; +pub mod prepend; pub mod replace; pub mod strip; pub mod unicode; @@ -7,6 +8,7 @@ pub mod utils; pub use crate::normalizers::bert::BertNormalizer; pub use crate::normalizers::precompiled::Precompiled; +pub use crate::normalizers::prepend::Prepend; pub use crate::normalizers::replace::Replace; pub use crate::normalizers::strip::{Strip, StripAccents}; pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD}; @@ -32,6 +34,7 @@ pub enum NormalizerWrapper { Nmt(Nmt), Precompiled(Precompiled), Replace(Replace), + Prepend(Prepend), } impl Normalizer for NormalizerWrapper { @@ -49,6 +52,7 @@ impl Normalizer for NormalizerWrapper { Self::Nmt(lc) => lc.normalize(normalized), Self::Precompiled(lc) => lc.normalize(normalized), Self::Replace(lc) => lc.normalize(normalized), + Self::Prepend(lc) => lc.normalize(normalized), } } } @@ -65,3 +69,4 @@ impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase); impl_enum_from!(Nmt, NormalizerWrapper, Nmt); impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled); impl_enum_from!(Replace, NormalizerWrapper, Replace); +impl_enum_from!(Prepend, 
NormalizerWrapper, Prepend); diff --git a/tokenizers/src/normalizers/prepend.rs b/tokenizers/src/normalizers/prepend.rs new file mode 100644 index 00000000..4e318c25 --- /dev/null +++ b/tokenizers/src/normalizers/prepend.rs @@ -0,0 +1,62 @@ +use crate::tokenizer::{NormalizedString, Normalizer, Result}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "type")] +pub struct Prepend { + pub prepend: String, +} + +impl Prepend { + pub fn new(prepend: String) -> Self { + Self { prepend } + } +} + +impl Normalizer for Prepend { + /// Prepend `self.prepend` to the normalized string in place; empty strings are left unchanged + fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> { + if !normalized.is_empty() { + normalized.prepend(&self.prepend); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prepend() { + let original = "Hello"; + let normalized = "▁Hello"; + assert_ne!(original, normalized); + let mut n = NormalizedString::from(original); + let prepend = Prepend::new("▁".to_string()); + prepend.normalize(&mut n).unwrap(); + assert_eq!(&n.get(), &normalized); + assert_eq!( + n, + NormalizedString::new( + original.to_string(), + normalized.to_string(), + vec![ + (0, 1), + (0, 1), + (0, 1), + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5) + ], + 0 + ) + ); + assert_eq!( + n.alignments_original(), + vec![(0, 4), (4, 5), (5, 6), (6, 7), (7, 8)] + ); + } +}