mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Python - Update bindings
This commit is contained in:
@ -1,115 +1,11 @@
|
||||
extern crate tokenizers as tk;
|
||||
|
||||
use crate::error::PyError;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use pyo3::{PyMappingProtocol, PyObjectProtocol, PySequenceProtocol};
|
||||
use pyo3::{PyObjectProtocol, PySequenceProtocol};
|
||||
use tk::tokenizer::PaddingDirection;
|
||||
|
||||
fn get_range(item: PyObject, max_len: usize) -> PyResult<std::ops::Range<usize>> {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
|
||||
let slice = if let Ok(index) = item.extract::<isize>(py) {
|
||||
if index >= max_len as isize || index < -(max_len as isize) {
|
||||
Err(exceptions::IndexError::py_err("Index out of bounds"))
|
||||
} else {
|
||||
Ok(if index == -1 {
|
||||
PySlice::new(py, index, max_len as isize, 1)
|
||||
} else {
|
||||
PySlice::new(py, index, index + 1, 1)
|
||||
})
|
||||
}
|
||||
} else if let Ok(slice) = item.cast_as::<PySlice>(py) {
|
||||
Ok(slice)
|
||||
} else if let Ok(offset) = item.cast_as::<PyTuple>(py) {
|
||||
if offset.len() == 2 {
|
||||
let start = offset.get_item(0).extract::<isize>()?;
|
||||
let end = offset.get_item(1).extract::<isize>()?;
|
||||
Ok(PySlice::new(py, start, end, 1))
|
||||
} else {
|
||||
Err(exceptions::TypeError::py_err("Expected Tuple[int, int]"))
|
||||
}
|
||||
} else {
|
||||
Err(exceptions::TypeError::py_err(
|
||||
"Expected number or slice or Tuple[int, int]",
|
||||
))
|
||||
}?;
|
||||
|
||||
// Find out range from the slice
|
||||
let len: std::os::raw::c_long = (max_len as i32) as _;
|
||||
let PySliceIndices { start, stop, .. } = slice.indices(len)?;
|
||||
|
||||
Ok(start as usize..stop as usize)
|
||||
}
|
||||
|
||||
/// Selects which view of the underlying `NormalizedString` an
/// `IndexableString` exposes.
enum IndexableStringType {
    /// Index into the original (pre-normalization) text.
    Original,
    /// Index into the normalized text.
    Normalized,
}
|
||||
|
||||
#[pyclass(dict)]
|
||||
pub struct IndexableString {
|
||||
s: tk::tokenizer::NormalizedString,
|
||||
t: IndexableStringType,
|
||||
}
|
||||
#[pymethods]
|
||||
impl IndexableString {
|
||||
fn offsets(&self, item: PyObject) -> PyResult<Option<(usize, usize)>> {
|
||||
let range = get_range(item, self.s.len())?;
|
||||
|
||||
match self.t {
|
||||
IndexableStringType::Original => Ok(self
|
||||
.s
|
||||
.get_original_offsets(range)
|
||||
.map(|range| (range.start, range.end))),
|
||||
IndexableStringType::Normalized => Ok(Some((range.start, range.end))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyObjectProtocol for IndexableString {
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(match self.t {
|
||||
IndexableStringType::Original => self.s.get_original().to_owned(),
|
||||
IndexableStringType::Normalized => self.s.get().to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
fn __str__(&self) -> PyResult<String> {
|
||||
Ok(match self.t {
|
||||
IndexableStringType::Original => self.s.get_original().to_owned(),
|
||||
IndexableStringType::Normalized => self.s.get().to_owned(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyproto]
|
||||
impl PyMappingProtocol for IndexableString {
|
||||
fn __getitem__(&self, item: PyObject) -> PyResult<String> {
|
||||
// Find out the range
|
||||
let range = get_range(item, self.s.len())?;
|
||||
|
||||
// Get the range from the relevant string
|
||||
let s = match self.t {
|
||||
IndexableStringType::Original => self.s.get_range_original(range),
|
||||
IndexableStringType::Normalized => self.s.get_range(range),
|
||||
};
|
||||
|
||||
s.map(|s| s.to_owned())
|
||||
.ok_or_else(|| exceptions::IndexError::py_err("Wrong offsets"))
|
||||
}
|
||||
|
||||
fn __len__(self) -> PyResult<usize> {
|
||||
Ok(match self.t {
|
||||
IndexableStringType::Original => self.s.len_original(),
|
||||
IndexableStringType::Normalized => self.s.len(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(dict)]
|
||||
#[repr(transparent)]
|
||||
pub struct Encoding {
|
||||
@ -127,7 +23,7 @@ impl PyObjectProtocol for Encoding {
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(format!(
|
||||
"Encoding(num_tokens={}, attributes=[ids, type_ids, tokens, offsets, \
|
||||
attention_mask, special_tokens_mask, overflowing, original_str, normalized_str])",
|
||||
attention_mask, special_tokens_mask, overflowing])",
|
||||
self.encoding.get_ids().len()
|
||||
))
|
||||
}
|
||||
@ -142,50 +38,6 @@ impl PySequenceProtocol for Encoding {
|
||||
|
||||
#[pymethods]
|
||||
impl Encoding {
|
||||
#[getter]
|
||||
fn get_normalized_str(&self) -> IndexableString {
|
||||
IndexableString {
|
||||
s: self.encoding.get_normalized().clone(),
|
||||
t: IndexableStringType::Normalized,
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_original_str(&self) -> IndexableString {
|
||||
IndexableString {
|
||||
s: self.encoding.get_normalized().clone(),
|
||||
t: IndexableStringType::Original,
|
||||
}
|
||||
}
|
||||
|
||||
#[args(kwargs = "**")]
|
||||
fn get_range(
|
||||
&self,
|
||||
range: (usize, usize),
|
||||
kwargs: Option<&PyDict>,
|
||||
) -> PyResult<Option<String>> {
|
||||
let mut original = false;
|
||||
if let Some(kwargs) = kwargs {
|
||||
if let Some(koriginal) = kwargs.get_item("original") {
|
||||
original = koriginal.extract()?;
|
||||
}
|
||||
}
|
||||
|
||||
if original {
|
||||
Ok(self
|
||||
.encoding
|
||||
.get_normalized()
|
||||
.get_range_original(range.0..range.1)
|
||||
.map(|s| s.to_owned()))
|
||||
} else {
|
||||
Ok(self
|
||||
.encoding
|
||||
.get_normalized()
|
||||
.get_range(range.0..range.1)
|
||||
.map(|s| s.to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_ids(&self) -> Vec<u32> {
|
||||
self.encoding.get_ids().to_vec()
|
||||
|
@ -159,6 +159,15 @@ impl Tokenizer {
|
||||
self.tokenizer.with_padding(None);
|
||||
}
|
||||
|
||||
fn normalize(&self, sentence: &str) -> PyResult<String> {
|
||||
ToPyResult(
|
||||
self.tokenizer
|
||||
.normalize(sentence)
|
||||
.map(|s| s.get().to_owned()),
|
||||
)
|
||||
.into()
|
||||
}
|
||||
|
||||
#[args(add_special_tokens = true)]
|
||||
fn encode(
|
||||
&self,
|
||||
|
@ -16,32 +16,9 @@ from typing import Optional, Union, List, Tuple
|
||||
|
||||
Offsets = Tuple[int, int]  # (start, end) character offsets into a string
|
||||
|
||||
class IndexableString:
    """Behaves mostly like a ``str``, but can additionally be indexed using
    the offsets provided by an ``Encoding``.
    """

    def offsets(self, offsets: Tuple[int, int]) -> Optional[Tuple[int, int]]:
        """Convert the Encoding's offsets to offsets valid on this string.

        The offsets exposed by an ``Encoding`` actually refer to the
        normalized version of the text. Passing them through this method
        yields offsets that can be used to index this ``str`` directly.
        """
        pass
||||
|
||||
class Encoding:
|
||||
""" An Encoding as returned by the Tokenizer """
|
||||
|
||||
@property
|
||||
def normalized_str(self) -> IndexableString:
|
||||
""" The normalized string """
|
||||
pass
|
||||
@property
|
||||
def original_str(self) -> IndexableString:
|
||||
""" The original string """
|
||||
pass
|
||||
@property
|
||||
def ids(self) -> List[int]:
|
||||
""" The tokenized ids """
|
||||
@ -244,6 +221,17 @@ class Tokenizer:
|
||||
def no_padding(self):
    """Disable padding entirely."""
    pass
|
||||
def normalize(self, sequence: str) -> str:
    """Normalize the given sequence.

    Args:
        sequence: str:
            The sequence to normalize

    Returns:
        The normalized string
    """
    pass
|
||||
def encode(
|
||||
self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
|
||||
) -> Encoding:
|
||||
|
Reference in New Issue
Block a user