diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs
index b979327e..948f7f2d 100644
--- a/bindings/python/src/encoding.rs
+++ b/bindings/python/src/encoding.rs
@@ -1,6 +1,7 @@
 extern crate tokenizers as tk;
 
 use crate::error::PyError;
+use crate::normalized_string::NormalizedString;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use tk::tokenizer::PaddingDirection;
@@ -20,13 +21,8 @@ impl Encoding {
 #[pymethods]
 impl Encoding {
     #[getter]
-    fn get_original(&self) -> String {
-        self.encoding.get_normalized().get_original().to_owned()
-    }
-
-    #[getter]
-    fn get_normalized(&self) -> String {
-        self.encoding.get_normalized().get().to_owned()
+    fn get_normalized(&self) -> NormalizedString {
+        NormalizedString::new(self.encoding.get_normalized().clone())
     }
 
     #[args(kwargs = "**")]
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 413e05a2..1f3ef991 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -2,6 +2,7 @@ mod decoders;
 mod encoding;
 mod error;
 mod models;
+mod normalized_string;
 mod normalizers;
 mod pre_tokenizers;
 mod processors;
@@ -70,6 +71,8 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
 #[pymodule]
 fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<Tokenizer>()?;
+    m.add_class::<Encoding>()?;
+    m.add_class::<NormalizedString>()?;
     m.add_wrapped(wrap_pymodule!(models))?;
     m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
     m.add_wrapped(wrap_pymodule!(decoders))?;
diff --git a/bindings/python/src/normalized_string.rs b/bindings/python/src/normalized_string.rs
new file mode 100644
index 00000000..f1e7328e
--- /dev/null
+++ b/bindings/python/src/normalized_string.rs
@@ -0,0 +1,35 @@
+extern crate tokenizers as tk;
+
+use pyo3::prelude::*;
+
+#[pyclass]
+#[repr(transparent)]
+pub struct NormalizedString {
+    s: tk::tokenizer::NormalizedString,
+}
+impl NormalizedString {
+    pub fn new(s: tk::tokenizer::NormalizedString) -> NormalizedString {
+        NormalizedString { s }
+    }
+}
+
+#[pymethods]
+impl NormalizedString {
+    #[getter]
+    fn get_original(&self) -> String {
+        self.s.get_original().to_owned()
+    }
+
+    #[getter]
+    fn get_normalized(&self) -> String {
+        self.s.get().to_owned()
+    }
+
+    fn get_range(&self, start: usize, end: usize) -> Option<String> {
+        self.s.get_range(start..end).map(|s| s.to_owned())
+    }
+
+    fn get_range_original(&self, start: usize, end: usize) -> Option<String> {
+        self.s.get_range_original(start..end).map(|s| s.to_owned())
+    }
+}
diff --git a/bindings/python/tokenizers/__init__.pyi b/bindings/python/tokenizers/__init__.pyi
index 91408555..e67431f6 100644
--- a/bindings/python/tokenizers/__init__.pyi
+++ b/bindings/python/tokenizers/__init__.pyi
@@ -9,10 +9,62 @@
 from typing import Optional, Union, List, Tuple
 
 Offsets = Tuple[int, int]
 
+class NormalizedString:
+    """ A NormalizedString produced during normalization """
+
+    @property
+    def original(self) -> str:
+        """ The original string """
+        pass
+
+    @property
+    def normalized(self) -> str:
+        """ The normalized string """
+        pass
+
+    def get_range(self, start: int, end: int) -> Optional[str]:
+        """ Return a range of the normalized string, if the bounds are correct
+
+        Args:
+            start: int:
+                The starting offset in the string
+
+            end: int:
+                The ending offset in the string
+
+        Returns:
+            The substring if the bounds are correct
+        """
+        pass
+
+    def get_range_original(self, start: int, end: int) -> Optional[str]:
+        """ Return a range of the original string, if the bounds are correct
+
+        The given bounds are supposed to be after-normalization-offsets.
+        Provided with the `Encoding.offsets` associated with an `Encoding.ids` unit,
+        this method will return the part of the original string corresponding to the id.
+
+        Args:
+            start: int:
+                The starting offset in the normalized string
+
+            end: int:
+                The ending offset in the normalized string
+
+        Returns:
+            The substring if the bounds are correct
+        """
+        pass
+
 class Encoding:
     """ An Encoding as returned by the Tokenizer """
 
+    @property
+    def normalized(self) -> NormalizedString:
+        """ The NormalizedString """
+        pass
+
     @property
     def ids(self) -> List[int]:
         """ The tokenized ids """