Python - Update bindings

Anthony MOI
2020-03-16 10:28:34 -04:00
parent e0cfad5102
commit 60a4fb35f4
3 changed files with 22 additions and 173 deletions

View File

@@ -1,115 +1,11 @@
extern crate tokenizers as tk;
use crate::error::PyError;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{PyMappingProtocol, PyObjectProtocol, PySequenceProtocol};
use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::PaddingDirection;
fn get_range(item: PyObject, max_len: usize) -> PyResult<std::ops::Range<usize>> {
    let gil = Python::acquire_gil();
    let py = gil.python();

    let slice = if let Ok(index) = item.extract::<isize>(py) {
        if index >= max_len as isize || index < -(max_len as isize) {
            Err(exceptions::IndexError::py_err("Index out of bounds"))
        } else {
            Ok(if index == -1 {
                // -1 needs a special case: (index, index + 1) would be (-1, 0), an empty slice
                PySlice::new(py, index, max_len as isize, 1)
            } else {
                PySlice::new(py, index, index + 1, 1)
            })
        }
    } else if let Ok(slice) = item.cast_as::<PySlice>(py) {
        Ok(slice)
    } else if let Ok(offset) = item.cast_as::<PyTuple>(py) {
        if offset.len() == 2 {
            let start = offset.get_item(0).extract::<isize>()?;
            let end = offset.get_item(1).extract::<isize>()?;
            Ok(PySlice::new(py, start, end, 1))
        } else {
            Err(exceptions::TypeError::py_err("Expected Tuple[int, int]"))
        }
    } else {
        Err(exceptions::TypeError::py_err(
            "Expected number or slice or Tuple[int, int]",
        ))
    }?;

    // Find out range from the slice
    let len: std::os::raw::c_long = (max_len as i32) as _;
    let PySliceIndices { start, stop, .. } = slice.indices(len)?;
    Ok(start as usize..stop as usize)
}
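For context: this removed helper funneled every Python indexing form into a single Rust range, which is what made the wrapper below subscriptable. A sketch of the forms it accepted, in Python (`enc` is a hypothetical `Encoding` from a build prior to this commit):

    # Illustrative only: IndexableString and its indexing are removed by this commit.
    s = enc.normalized_str  # behaved almost like a str
    s[3]                    # plain index: a one-character string
    s[-1]                   # negative index: the last character
    s[2:8]                  # ordinary Python slice
    s[(2, 8)]               # (start, end) tuple, i.e. an Encoding offset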
enum IndexableStringType {
    Original,
    Normalized,
}

#[pyclass(dict)]
pub struct IndexableString {
    s: tk::tokenizer::NormalizedString,
    t: IndexableStringType,
}
#[pymethods]
impl IndexableString {
    fn offsets(&self, item: PyObject) -> PyResult<Option<(usize, usize)>> {
        let range = get_range(item, self.s.len())?;
        match self.t {
            IndexableStringType::Original => Ok(self
                .s
                .get_original_offsets(range)
                .map(|range| (range.start, range.end))),
            IndexableStringType::Normalized => Ok(Some((range.start, range.end))),
        }
    }
}
#[pyproto]
impl PyObjectProtocol for IndexableString {
    fn __repr__(&self) -> PyResult<String> {
        Ok(match self.t {
            IndexableStringType::Original => self.s.get_original().to_owned(),
            IndexableStringType::Normalized => self.s.get().to_owned(),
        })
    }

    fn __str__(&self) -> PyResult<String> {
        Ok(match self.t {
            IndexableStringType::Original => self.s.get_original().to_owned(),
            IndexableStringType::Normalized => self.s.get().to_owned(),
        })
    }
}
#[pyproto]
impl PyMappingProtocol for IndexableString {
    fn __getitem__(&self, item: PyObject) -> PyResult<String> {
        // Find out the range
        let range = get_range(item, self.s.len())?;
        // Get the range from the relevant string
        let s = match self.t {
            IndexableStringType::Original => self.s.get_range_original(range),
            IndexableStringType::Normalized => self.s.get_range(range),
        };
        s.map(|s| s.to_owned())
            .ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
    }

    fn __len__(&self) -> PyResult<usize> {
        Ok(match self.t {
            IndexableStringType::Original => self.s.len_original(),
            IndexableStringType::Normalized => self.s.len(),
        })
    }
}
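The two protocol impls above made the wrapper print and measure like a plain string, each against its own view of the text. The old behavior, sketched with the same hypothetical `enc`:

    str(enc.original_str)    # the original input text
    str(enc.normalized_str)  # the text after normalization
    len(enc.original_str)    # length of the original text, not the normalized one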
#[pyclass(dict)]
#[repr(transparent)]
pub struct Encoding {
@@ -127,7 +23,7 @@ impl PyObjectProtocol for Encoding {
    fn __repr__(&self) -> PyResult<String> {
        Ok(format!(
            "Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
             attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
             attention_mask, special_tokens_mask, overflowing])",
            self.encoding.get_ids().len()
        ))
    }
@@ -142,50 +38,6 @@ impl PySequenceProtocol for Encoding {
#[pymethods]
impl Encoding {
    #[getter]
    fn get_normalized_str(&self) -> IndexableString {
        IndexableString {
            s: self.encoding.get_normalized().clone(),
            t: IndexableStringType::Normalized,
        }
    }

    #[getter]
    fn get_original_str(&self) -> IndexableString {
        IndexableString {
            // The NormalizedString also carries the original text; `t` selects the view
            s: self.encoding.get_normalized().clone(),
            t: IndexableStringType::Original,
        }
    }

    #[args(kwargs = "**")]
    fn get_range(
        &self,
        range: (usize, usize),
        kwargs: Option<&PyDict>,
    ) -> PyResult<Option<String>> {
        let mut original = false;
        if let Some(kwargs) = kwargs {
            if let Some(koriginal) = kwargs.get_item("original") {
                original = koriginal.extract()?;
            }
        }

        if original {
            Ok(self
                .encoding
                .get_normalized()
                .get_range_original(range.0..range.1)
                .map(|s| s.to_owned()))
        } else {
            Ok(self
                .encoding
                .get_normalized()
                .get_range(range.0..range.1)
                .map(|s| s.to_owned()))
        }
    }
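This `get_range` method goes away together with the wrappers; it took a `(start, end)` pair plus an optional `original` keyword choosing which string to slice. The old call pattern, sketched (hypothetical `enc` again):

    # Pre-change API, removed by this commit.
    enc.get_range((0, 5))                 # substring of the normalized text
    enc.get_range((0, 5), original=True)  # matching substring of the original text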
    #[getter]
    fn get_ids(&self) -> Vec<u32> {
        self.encoding.get_ids().to_vec()

View File

@@ -159,6 +159,15 @@ impl Tokenizer {
        self.tokenizer.with_padding(None);
    }

    fn normalize(&self, sentence: &str) -> PyResult<String> {
        ToPyResult(
            self.tokenizer
                .normalize(sentence)
                .map(|s| s.get().to_owned()),
        )
        .into()
    }

    #[args(add_special_tokens = true)]
    fn encode(
        &self,

View File

@@ -16,32 +16,9 @@ from typing import Optional, Union, List, Tuple
Offsets = Tuple[int, int]
class IndexableString:
    """ Works almost like a `str`, but allows indexing with the offsets
    provided by an `Encoding`
    """

    def offsets(self, offsets: Tuple[int, int]) -> Optional[Tuple[int, int]]:
        """ Convert the Encoding's offsets so they index the current string.

        An `Encoding` provides offsets that actually point into the normalized
        version of the text. Passing them through this method converts them so
        that they can be used to index this `str` directly.
        """
        pass
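Taken together, the removed stubs describe a round trip from token offsets back to the original text. A sketch of that old flow (hypothetical `enc`), dropped along with the method and the wrapper:

    # Illustrative pre-change usage of the removed offsets() conversion.
    span = enc.offsets[1]                  # offsets of token 1, in normalized space
    orig = enc.original_str.offsets(span)  # remap; may be None if unmappable
    if orig is not None:
        print(enc.original_str[orig])      # index the original text with the remapped pair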
class Encoding:
    """ An Encoding as returned by the Tokenizer """

    @property
    def normalized_str(self) -> IndexableString:
        """ The normalized string """
        pass

    @property
    def original_str(self) -> IndexableString:
        """ The original string """
        pass

    @property
    def ids(self) -> List[int]:
        """ The tokenized ids """
@@ -244,6 +221,17 @@ class Tokenizer:
    def no_padding(self):
        """ Disable padding """
        pass

    def normalize(self, sequence: str) -> str:
        """ Normalize the given sequence

        Args:
            sequence: str:
                The sequence to normalize

        Returns:
            The normalized string
        """
        pass
    def encode(
        self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
    ) -> Encoding:
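A short usage sketch tying the updated surface together. The setup is assumed (any already-built `Tokenizer`; construction varies by version), and the printed values depend on its normalizer and vocabulary:

    # `tokenizer` is assumed to be an already-configured Tokenizer instance.
    print(tokenizer.normalize("Héllo WORLD"))  # e.g. "hello world" with a lowercasing normalizer

    encoding = tokenizer.encode("Hello, world!", add_special_tokens=True)
    print(encoding.ids)      # token ids
    print(encoding.tokens)   # string tokens
    print(encoding.offsets)  # (start, end) pairs, per the attributes the new __repr__ lists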