diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs index 88f3dd79..487fc306 100644 --- a/bindings/python/src/encoding.rs +++ b/bindings/python/src/encoding.rs @@ -7,6 +7,43 @@ use pyo3::types::*; use pyo3::{PyMappingProtocol, PyObjectProtocol}; use tk::tokenizer::PaddingDirection; +fn get_range(item: PyObject, max_len: usize) -> PyResult> { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let slice = if let Ok(index) = item.extract::(py) { + if index >= max_len as isize || index < -(max_len as isize) { + Err(exceptions::IndexError::py_err("Index out of bounds")) + } else { + Ok(if index == -1 { + PySlice::new(py, index, max_len as isize, 1) + } else { + PySlice::new(py, index, index + 1, 1) + }) + } + } else if let Ok(slice) = item.cast_as::(py) { + Ok(slice) + } else if let Ok(offset) = item.cast_as::(py) { + if offset.len() == 2 { + let start = offset.get_item(0).extract::()?; + let end = offset.get_item(1).extract::()?; + Ok(PySlice::new(py, start, end, 1)) + } else { + Err(exceptions::TypeError::py_err("Expected Tuple[int, int]")) + } + } else { + Err(exceptions::TypeError::py_err( + "Expected number or slice or Tuple[int, int]", + )) + }?; + + // Find out range from the slice + let len: std::os::raw::c_long = (max_len as i32) as _; + let PySliceIndices { start, stop, .. } = slice.indices(len)?; + + Ok(start as usize..stop as usize) +} + enum IndexableStringType { Original, Normalized, @@ -18,7 +55,19 @@ pub struct IndexableString { t: IndexableStringType, } #[pymethods] -impl IndexableString {} +impl IndexableString { + fn offsets(&self, item: PyObject) -> PyResult> { + let range = get_range(item, self.s.len())?; + + match self.t { + IndexableStringType::Original => Ok(self + .s + .get_original_offsets(range) + .map(|range| (range.start, range.end))), + IndexableStringType::Normalized => Ok(Some((range.start, range.end))), + } + } +} #[pyproto] impl PyObjectProtocol for IndexableString { @@ -40,40 +89,8 @@ impl PyObjectProtocol for IndexableString { #[pyproto] impl PyMappingProtocol for IndexableString { fn __getitem__(&self, item: PyObject) -> PyResult { - let gil = Python::acquire_gil(); - let py = gil.python(); - - // Make a slice from a number or get a slice directly - let slice = if let Ok(index) = item.extract::(py) { - if index >= self.s.len() as isize || index < -(self.s.len() as isize) { - Err(exceptions::IndexError::py_err("Index out of bounds")) - } else { - Ok(if index == -1 { - PySlice::new(py, index, self.s.len() as isize, 1) - } else { - PySlice::new(py, index, index + 1, 1) - }) - } - } else if let Ok(slice) = item.cast_as::(py) { - Ok(slice) - } else if let Ok(offset) = item.cast_as::(py) { - if offset.len() == 2 { - let start = offset.get_item(0).extract::()?; - let end = offset.get_item(1).extract::()?; - Ok(PySlice::new(py, start, end, 1)) - } else { - Err(exceptions::TypeError::py_err("Expected Tuple[int, int]")) - } - } else { - Err(exceptions::TypeError::py_err( - "Expected number or slice or Tuple[int, int]", - )) - }?; - - // Find out range from the slice - let len: std::os::raw::c_long = (self.s.len() as i32) as _; - let PySliceIndices { start, stop, .. } = slice.indices(len)?; - let range = start as usize..stop as usize; + // Find out the range + let range = get_range(item, self.s.len())?; // Get the range from the relevant string let s = match self.t { diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 1982e9fb..570e0624 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -44,6 +44,26 @@ impl NormalizedString { &self.original } + /// Return the range of the original string corresponding to the received range on the + /// normalized string. Returns None if out of bounds + pub fn get_original_offsets( + &self, + range: std::ops::Range, + ) -> Option> { + self.alignments + .get(range) + .map(|alignments| { + if alignments.is_empty() { + None + } else { + let start = alignments[0].0; + let end = alignments[alignments.len() - 1].1; + Some(start..end) + } + }) + .flatten() + } + fn get_range_of(&self, s: &str, range: std::ops::Range) -> Option { let len = s.chars().count(); if range.start >= len || range.end > len { @@ -75,17 +95,8 @@ impl NormalizedString { /// Return a range of the original string, using a range from the normalized string pub fn get_range_original(&self, range: std::ops::Range) -> Option { - self.alignments - .get(range) - .map(|alignments| { - if alignments.is_empty() { - None - } else { - let start = alignments[0].0; - let end = alignments[alignments.len() - 1].1; - self.get_range_of(&self.original, start..end) - } - }) + self.get_original_offsets(range) + .map(|range| self.get_range_of(&self.original, range)) .flatten() }