mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Python - Extract single pre-tokenized inputs from np.array
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use numpy::PyArray1;
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
@ -188,6 +189,90 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Owned strings read out of a numpy array.
///
/// The `FromPyObject` impl for this type only accepts arrays whose dtype number is 19
/// (unicode, per the comment in that impl).
struct PyArrayUnicode(Vec<String>);
|
||||
impl FromPyObject<'_> for PyArrayUnicode {
|
||||
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||
let array = ob.downcast::<PyArray1<u8>>()?;
|
||||
let arr = array.as_array_ptr();
|
||||
let (type_num, elsize, alignment, data) = unsafe {
|
||||
let desc = (*arr).descr;
|
||||
(
|
||||
(*desc).type_num,
|
||||
(*desc).elsize as usize,
|
||||
(*desc).alignment as usize,
|
||||
(*arr).data,
|
||||
)
|
||||
};
|
||||
let n_elem = array.shape()[0];
|
||||
|
||||
// type_num == 19 => Unicode
|
||||
if type_num != 19 {
|
||||
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem);
|
||||
|
||||
let seq = (0..n_elem)
|
||||
.map(|i| {
|
||||
let bytes = &all_bytes[i * elsize..(i + 1) * elsize];
|
||||
let unicode = pyo3::ffi::PyUnicode_FromUnicode(
|
||||
bytes.as_ptr() as *const _,
|
||||
elsize as isize / alignment as isize,
|
||||
);
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let obj = PyObject::from_owned_ptr(py, unicode);
|
||||
let s = obj.cast_as::<PyString>(py)?;
|
||||
Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
|
||||
Ok(Self(seq))
|
||||
}
|
||||
}
|
||||
}
|
||||
impl From<PyArrayUnicode> for tk::InputSequence<'_> {
|
||||
fn from(s: PyArrayUnicode) -> Self {
|
||||
s.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Owned strings read out of a numpy array.
///
/// The `FromPyObject` impl for this type only accepts arrays whose dtype number is 17
/// and casts each element to a Python `str`.
struct PyArrayStr(Vec<String>);
|
||||
impl FromPyObject<'_> for PyArrayStr {
|
||||
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||
let array = ob.downcast::<PyArray1<u8>>()?;
|
||||
let arr = array.as_array_ptr();
|
||||
let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) };
|
||||
let n_elem = array.shape()[0];
|
||||
|
||||
if type_num != 17 {
|
||||
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let objects = std::slice::from_raw_parts(data as *const PyObject, n_elem);
|
||||
|
||||
let seq = objects
|
||||
.into_iter()
|
||||
.map(|obj| {
|
||||
let gil = Python::acquire_gil();
|
||||
let py = gil.python();
|
||||
let s = obj.cast_as::<PyString>(py)?;
|
||||
Ok(s.to_string()?.into_owned())
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
|
||||
Ok(Self(seq))
|
||||
}
|
||||
}
|
||||
}
|
||||
impl From<PyArrayStr> for tk::InputSequence<'_> {
|
||||
fn from(s: PyArrayStr) -> Self {
|
||||
s.0.into()
|
||||
}
|
||||
}
|
||||
|
||||
/// Newtype around `tk::InputSequence` that is built from a Python object through its
/// `FromPyObject` impl.
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
|
||||
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
|
||||
fn extract(ob: &'s PyAny) -> PyResult<Self> {
|
||||
@ -195,11 +280,15 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
|
||||
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
|
||||
);
|
||||
|
||||
if let Ok(s) = ob.downcast::<PyList>() {
|
||||
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
||||
if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
|
||||
Ok(Self(seq.into()))
|
||||
} else if let Ok(seq) = ob.extract::<PyArrayStr>() {
|
||||
Ok(Self(seq.into()))
|
||||
} else if let Ok(s) = ob.downcast::<PyList>() {
|
||||
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
|
||||
Ok(Self(seq.into()))
|
||||
} else if let Ok(s) = ob.downcast::<PyTuple>() {
|
||||
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
||||
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
|
||||
Ok(Self(seq.into()))
|
||||
} else {
|
||||
Err(err)
|
||||
|
Reference in New Issue
Block a user