Python - IndexableString in Encoding

This commit is contained in:
Anthony MOI
2020-01-08 00:06:57 -05:00
parent fb250fd7fc
commit 88711d5717
4 changed files with 104 additions and 91 deletions

View File

@ -1,12 +1,87 @@
extern crate tokenizers as tk;
use crate::error::PyError;
use crate::normalized_string::NormalizedString;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::PyObjectProtocol;
use pyo3::{PyMappingProtocol, PyObjectProtocol};
use tk::tokenizer::PaddingDirection;
/// Selects which text of the wrapped normalized string an `IndexableString`
/// exposes.
///
/// The variant is a plain tag with no payload, so it derives the cheap common
/// traits; this lets callers compare it with `==`, copy it freely and print
/// it in debug output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IndexableStringType {
    /// Expose the original text (`get_original`, `len_original`).
    Original,
    /// Expose the normalized text (`get`, `len`).
    Normalized,
}
/// Python-visible wrapper around a tokenizers `NormalizedString` that presents
/// one of its two texts (original or normalized) as a string-like object.
///
/// The protocol impls in this file give it `__repr__`, `__str__`,
/// `__getitem__` and `__len__`.
#[pyclass(dict)]
pub struct IndexableString {
    // The wrapped string; holds both the original and the normalized text.
    s: tk::tokenizer::NormalizedString,
    // Which of the two texts this instance exposes.
    t: IndexableStringType,
}
// No regular Python methods are defined on `IndexableString`; the block is
// intentionally empty. All behaviour comes from the protocol impls.
#[pymethods]
impl IndexableString {}
#[pyproto]
impl PyObjectProtocol for IndexableString {
    // `repr(obj)`: an owned copy of whichever text this instance exposes,
    // without any decoration.
    fn __repr__(&self) -> PyResult<String> {
        let text: &str = match self.t {
            IndexableStringType::Normalized => self.s.get(),
            IndexableStringType::Original => self.s.get_original(),
        };
        Ok(text.to_owned())
    }

    // `str(obj)`: same content as `__repr__`.
    fn __str__(&self) -> PyResult<String> {
        let text: &str = match self.t {
            IndexableStringType::Normalized => self.s.get(),
            IndexableStringType::Original => self.s.get_original(),
        };
        Ok(text.to_owned())
    }
}
#[pyproto]
impl PyMappingProtocol for IndexableString {
    // `obj[item]`: returns the requested part of the exposed text.
    //
    // `item` may be an integer (negative values count from the end) or a
    // slice object. Errors:
    //   - IndexError("Index out of bounds") for an integer outside
    //     `-len..len`
    //   - TypeError("Expected number or slice") for any other argument type
    //   - IndexError("Wrong offsets") when the wrapped string cannot produce
    //     the requested range
    //
    // NOTE(review): only `start` and `stop` of the resolved slice are used;
    // the remaining fields are discarded, so a slice step is presumably not
    // honoured -- confirm whether that is intended.
    fn __getitem__(&self, item: PyObject) -> PyResult<String> {
        let gil = Python::acquire_gil();
        let py = gil.python();

        // Bounds come from the text this instance exposes, i.e. the same
        // length `__len__` reports. Using the normalized length for the
        // original view would reject valid indices or accept invalid ones
        // whenever the two texts differ in length.
        let len = match self.t {
            IndexableStringType::Original => self.s.len_original(),
            IndexableStringType::Normalized => self.s.len(),
        };
        let len = len as isize;

        // Make a slice from a number or get a slice directly
        let slice = if let Ok(index) = item.extract::<isize>(py) {
            if index >= len || index < -len {
                Err(exceptions::IndexError::py_err("Index out of bounds"))
            } else {
                Ok(if index == -1 {
                    // `index + 1` would be 0 here and give an empty slice,
                    // so the last element gets an explicit end.
                    PySlice::new(py, index, len, 1)
                } else {
                    PySlice::new(py, index, index + 1, 1)
                })
            }
        } else if let Ok(slice) = item.cast_as::<PySlice>(py) {
            Ok(slice)
        } else {
            Err(exceptions::TypeError::py_err("Expected number or slice"))
        }?;

        // Find out range from the slice
        let PySliceIndices { start, stop, .. } = slice.indices(len as std::os::raw::c_long)?;
        let range = start as usize..stop as usize;

        // Get the range from the relevant string: the `_original` accessor
        // serves the original view, the plain accessor the normalized one,
        // mirroring `get_original`/`get` and `len_original`/`len`.
        let s = match self.t {
            IndexableStringType::Original => self.s.get_range_original(range),
            IndexableStringType::Normalized => self.s.get_range(range),
        };
        s.map(|s| s.to_owned())
            .ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
    }

    // `len(obj)`: length of the exposed text as reported by the wrapped
    // string. Takes `&self` like every other protocol method so the call
    // does not consume the instance.
    fn __len__(&self) -> PyResult<usize> {
        Ok(match self.t {
            IndexableStringType::Original => self.s.len_original(),
            IndexableStringType::Normalized => self.s.len(),
        })
    }
}
#[pyclass(dict)]
#[repr(transparent)]
pub struct Encoding {
@ -24,7 +99,7 @@ impl PyObjectProtocol for Encoding {
fn __repr__(&self) -> PyResult<String> {
Ok(format!(
"Encoding(num_tokens={}, attributs=[ids, type_ids, tokens, offsets, \
attention_mask, special_tokens_mask, overflowing])",
attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
self.encoding.get_ids().len()
))
}
@ -33,8 +108,19 @@ impl PyObjectProtocol for Encoding {
#[pymethods]
impl Encoding {
#[getter]
fn get_normalized(&self) -> NormalizedString {
NormalizedString::new(self.encoding.get_normalized().clone())
fn get_normalized_str(&self) -> IndexableString {
IndexableString {
s: self.encoding.get_normalized().clone(),
t: IndexableStringType::Normalized,
}
}
// Python attribute `original_str`: an `IndexableString` over a clone of this
// encoding's normalized string, set to expose the original text.
#[getter]
fn get_original_str(&self) -> IndexableString {
    let wrapped = self.encoding.get_normalized().clone();
    IndexableString {
        t: IndexableStringType::Original,
        s: wrapped,
    }
}
#[args(kwargs = "**")]