mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Python - IndexableString in Encoding
This commit is contained in:
@ -1,12 +1,87 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use crate::error::PyError;
|
||||
use crate::normalized_string::NormalizedString;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use pyo3::PyObjectProtocol;
|
||||
use pyo3::{PyMappingProtocol, PyObjectProtocol};
|
||||
use tk::tokenizer::PaddingDirection;
|
||||
|
||||
enum IndexableStringType {
|
||||
Original,
|
||||
Normalized,
|
||||
}
|
||||
|
||||
#[pyclass(dict)]
|
||||
pub struct IndexableString {
|
||||
s: tk::tokenizer::NormalizedString,
|
||||
t: IndexableStringType,
|
||||
}
|
||||
#[pymethods]
|
||||
impl IndexableString {}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for IndexableString {
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(match self.t {
|
||||
IndexableStringType::Original => self.s.get_original().to_owned(),
|
||||
IndexableStringType::Normalized => self.s.get().to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
fn __str__(&self) -> PyResult<String> {
|
||||
Ok(match self.t {
|
||||
IndexableStringType::Original => self.s.get_original().to_owned(),
|
||||
IndexableStringType::Normalized => self.s.get().to_owned(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyMappingProtocol for IndexableString {
|
||||
fn __getitem__(&self, item: PyObject) -> PyResult<String> {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
// Make a slice from a number or get a slice directly
|
||||
let slice = if let Ok(index) = item.extract::<isize>(py) {
|
||||
if index >= self.s.len() as isize || index < -(self.s.len() as isize) {
|
||||
Err(exceptions::IndexError::py_err("Index out of bounds"))
|
||||
} else {
|
||||
Ok(if index == -1 {
|
||||
PySlice::new(py, index, self.s.len() as isize, 1)
|
||||
} else {
|
||||
PySlice::new(py, index, index + 1, 1)
|
||||
})
|
||||
}
|
||||
} else if let Ok(slice) = item.cast_as::<PySlice>(py) {
|
||||
Ok(slice)
|
||||
} else {
|
||||
Err(exceptions::TypeError::py_err("Expected number or slice"))
|
||||
}?;
|
||||
|
||||
// Find out range from the slice
|
||||
let PySliceIndices { start, stop, .. } = slice.indices(self.s.len() as i64)?;
|
||||
let range = start as usize..stop as usize;
|
||||
|
||||
// Get the range from the relevant string
|
||||
let s = match self.t {
|
||||
IndexableStringType::Original => self.s.get_range(range),
|
||||
IndexableStringType::Normalized => self.s.get_range_original(range),
|
||||
};
|
||||
|
||||
s.map(|s| s.to_owned())
|
||||
.ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
|
||||
}
|
||||
|
||||
fn __len__(self) -> PyResult<usize> {
|
||||
Ok(match self.t {
|
||||
IndexableStringType::Original => self.s.len_original(),
|
||||
IndexableStringType::Normalized => self.s.len(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(dict)]
|
||||
#[repr(transparent)]
|
||||
pub struct Encoding {
|
||||
@ -24,7 +99,7 @@ impl PyObjectProtocol for Encoding {
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(format!(
|
||||
"Encoding(num_tokens={}, attributs=[ids, type_ids, tokens, offsets, \
|
||||
attention_mask, special_tokens_mask, overflowing])",
|
||||
attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
|
||||
self.encoding.get_ids().len()
|
||||
))
|
||||
}
|
||||
@ -33,8 +108,19 @@ impl PyObjectProtocol for Encoding {
|
||||
#[pymethods]
|
||||
impl Encoding {
|
||||
#[getter]
|
||||
fn get_normalized(&self) -> NormalizedString {
|
||||
NormalizedString::new(self.encoding.get_normalized().clone())
|
||||
fn get_normalized_str(&self) -> IndexableString {
|
||||
IndexableString {
|
||||
s: self.encoding.get_normalized().clone(),
|
||||
t: IndexableStringType::Normalized,
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_original_str(&self) -> IndexableString {
|
||||
IndexableString {
|
||||
s: self.encoding.get_normalized().clone(),
|
||||
t: IndexableStringType::Original,
|
||||
}
|
||||
}
|
||||
|
||||
#[args(kwargs = "**")]
|
||||
|
Reference in New Issue
Block a user