diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 25fdc012..2ecd6709 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -2,6 +2,7 @@ mod decoders;
 mod encoding;
 mod error;
 mod models;
+mod normalizers;
 mod pre_tokenizers;
 mod processors;
 mod token;
@@ -55,6 +56,14 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
     Ok(())
 }
 
+/// Normalizers Module
+#[pymodule]
+fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<Normalizer>()?;
+    m.add_class::<BertNormalizer>()?;
+    Ok(())
+}
+
 /// Tokenizers Module
 #[pymodule]
 fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -63,6 +72,7 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
     m.add_wrapped(wrap_pymodule!(decoders))?;
     m.add_wrapped(wrap_pymodule!(processors))?;
+    m.add_wrapped(wrap_pymodule!(normalizers))?;
     m.add_wrapped(wrap_pymodule!(trainers))?;
     Ok(())
 }
diff --git a/bindings/python/src/normalizers.rs b/bindings/python/src/normalizers.rs
new file mode 100644
index 00000000..1a93ed5a
--- /dev/null
+++ b/bindings/python/src/normalizers.rs
@@ -0,0 +1,48 @@
+extern crate tokenizers as tk;
+
+use super::error::{PyError, ToPyResult};
+use super::utils::Container;
+use pyo3::prelude::*;
+use pyo3::types::*;
+use tk::tokenizer::Result;
+
+#[pyclass(dict)]
+pub struct Normalizer {
+    pub normalizer: Container<dyn tk::tokenizer::Normalizer + Sync>,
+}
+
+#[pyclass]
+pub struct BertNormalizer {}
+#[pymethods]
+impl BertNormalizer {
+    #[staticmethod]
+    #[args(kwargs = "**")]
+    fn new(kwargs: Option<&PyDict>) -> PyResult<Normalizer> {
+        let mut clean_text = true;
+        let mut handle_chinese_chars = true;
+        let mut strip_accents = true;
+        let mut lowercase = true;
+
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "clean_text" => clean_text = value.extract()?,
+                    "handle_chinese_chars" => handle_chinese_chars = value.extract()?,
+                    "strip_accents" => strip_accents = value.extract()?,
+                    "lowercase" => lowercase = value.extract()?,
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
+
+        Ok(Normalizer {
+            normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
+                clean_text,
+                handle_chinese_chars,
+                strip_accents,
+                lowercase,
+            ))),
+        })
+    }
+}
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 6962658e..d50f1c44 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -8,6 +8,7 @@ use super::decoders::Decoder;
 use super::encoding::Encoding;
 use super::error::{PyError, ToPyResult};
 use super::models::Model;
+use super::normalizers::Normalizer;
 use super::pre_tokenizers::PreTokenizer;
 use super::processors::PostProcessor;
 use super::trainers::Trainer;
@@ -97,6 +98,17 @@ impl Tokenizer {
         }
     }
 
+    fn with_normalizer(&mut self, normalizer: &mut Normalizer) -> PyResult<()> {
+        if let Some(normalizer) = normalizer.normalizer.to_pointer() {
+            self.tokenizer.with_normalizer(normalizer);
+            Ok(())
+        } else {
+            Err(exceptions::Exception::py_err(
+                "The Normalizer is already being used in another Tokenizer",
+            ))
+        }
+    }
+
     #[args(kwargs = "**")]
     fn with_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
         let mut stride = 0;
diff --git a/bindings/python/tokenizers/__init__.py b/bindings/python/tokenizers/__init__.py
index 8a39e493..a20eca31 100644
--- a/bindings/python/tokenizers/__init__.py
+++ b/bindings/python/tokenizers/__init__.py
@@ -1,3 +1,3 @@
 __version__ = "0.0.11"
 
-from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors
+from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors, normalizers
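
For reference, a minimal sketch of how the new `normalizers` module is meant to be used from Python. The `models.BPE.empty()` constructor below is only an illustrative assumption for getting a `Tokenizer` instance and is not part of this change; the normalizer-related calls follow the bindings added above.

```python
from tokenizers import Tokenizer, models, normalizers

# Assumed model constructor, for illustration only; any Model works here.
tokenizer = Tokenizer(models.BPE.empty())

# BertNormalizer.new is a staticmethod returning a Normalizer; all options
# default to True, and unknown keyword arguments are ignored with a printed
# message (see normalizers.rs above).
normalizer = normalizers.BertNormalizer.new(
    clean_text=True,
    handle_chinese_chars=True,
    strip_accents=True,
    lowercase=True,
)

# Attaching the normalizer moves ownership into the Tokenizer; reusing the
# same Normalizer object with another Tokenizer raises an exception.
tokenizer.with_normalizer(normalizer)
```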