Python - Extract single pre-tokenized inputs from np.array

This commit is contained in:
Anthony MOI
2020-08-20 23:54:49 -04:00
committed by Anthony MOI
parent d919d68889
commit 14adf18e5b
3 changed files with 167 additions and 3 deletions

View File

@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use numpy::PyArray1;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
@ -188,6 +189,90 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
}
}
/// Strings extracted from a one-dimensional numpy array with a unicode (`U`) dtype.
struct PyArrayUnicode(Vec<String>);

impl FromPyObject<'_> for PyArrayUnicode {
    /// Converts `ob` into a list of Rust strings.
    ///
    /// # Errors
    ///
    /// Returns an error when `ob` is not a numpy array, when its dtype is not
    /// unicode, when it is not exactly one-dimensional, or when Python fails to
    /// build a `str` from one of the items.
    fn extract(ob: &PyAny) -> PyResult<Self> {
        let array = ob.downcast::<PyArray1<u8>>()?;
        let arr = array.as_array_ptr();
        // SAFETY: `arr` points to the array object borrowed through `ob`, which
        // stays alive for the whole call; only plain header fields are read.
        let (type_num, elsize, alignment, data, nd) = unsafe {
            let desc = (*arr).descr;
            (
                (*desc).type_num,
                (*desc).elsize as usize,
                (*desc).alignment as usize,
                (*arr).data,
                (*arr).nd,
            )
        };

        // type_num == 19 => Unicode
        // The number of dimensions is checked explicitly instead of relying on
        // the downcast: indexing `shape()[0]` on a 0-d array would panic, and
        // an N-d array would be read as if it were flat.
        if type_num != 19 || nd != 1 {
            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
        }
        let n_elem = array.shape()[0];
        // Byte distance between consecutive items. It differs from `elsize` for
        // non-contiguous views (e.g. `arr[::2]`) and may be negative.
        let stride = array.strides()[0];

        // NOTE(review): like the original code, this uses the dtype alignment
        // as the width of one code unit — confirm it always matches the width
        // `PyUnicode_FromUnicode` expects on the target platform.
        let n_chars = elsize
            .checked_div(alignment)
            .ok_or_else(|| exceptions::TypeError::py_err("Expected a np.array[str]"))?
            as isize;

        // Acquire the GIL guard once instead of once per element.
        let gil = Python::acquire_gil();
        let py = gil.python();

        let seq = (0..n_elem)
            .map(|i| {
                // SAFETY: `i < n_elem`, so `data + i * stride` is the start of
                // item `i`, which spans `elsize` bytes inside the array buffer.
                let unicode = unsafe {
                    let item = (data as *const u8).offset(i as isize * stride);
                    pyo3::ffi::PyUnicode_FromUnicode(item as *const _, n_chars)
                };
                // A NULL result means Python raised; surface it as an error
                // instead of wrapping a null pointer.
                // SAFETY: `unicode` is either NULL or a new owned reference.
                let obj = unsafe { PyObject::from_owned_ptr_or_err(py, unicode) }?;
                let s = obj.cast_as::<PyString>(py)?;
                // numpy pads shorter strings with NULs up to the fixed item
                // width, so strip them.
                Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
            })
            .collect::<PyResult<Vec<_>>>()?;

        Ok(Self(seq))
    }
}
impl From<PyArrayUnicode> for tk::InputSequence<'_> {
fn from(s: PyArrayUnicode) -> Self {
s.0.into()
}
}
struct PyArrayStr(Vec<String>);
impl FromPyObject<'_> for PyArrayStr {
fn extract(ob: &PyAny) -> PyResult<Self> {
let array = ob.downcast::<PyArray1<u8>>()?;
let arr = array.as_array_ptr();
let (type_num, data) = unsafe { ((*(*arr).descr).type_num, (*arr).data) };
let n_elem = array.shape()[0];
if type_num != 17 {
return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
}
unsafe {
let objects = std::slice::from_raw_parts(data as *const PyObject, n_elem);
let seq = objects
.into_iter()
.map(|obj| {
let gil = Python::acquire_gil();
let py = gil.python();
let s = obj.cast_as::<PyString>(py)?;
Ok(s.to_string()?.into_owned())
})
.collect::<PyResult<Vec<_>>>()?;
Ok(Self(seq))
}
}
}
impl From<PyArrayStr> for tk::InputSequence<'_> {
fn from(s: PyArrayStr) -> Self {
s.0.into()
}
}
/// Newtype around `tk::InputSequence` that carries an input which is already
/// split into words; it is built from a Python object through the
/// `FromPyObject` implementation that follows.
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
@ -195,11 +280,15 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
);
if let Ok(s) = ob.downcast::<PyList>() {
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
Ok(Self(seq.into()))
} else if let Ok(seq) = ob.extract::<PyArrayStr>() {
Ok(Self(seq.into()))
} else if let Ok(s) = ob.downcast::<PyList>() {
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
Ok(Self(seq.into()))
} else if let Ok(s) = ob.downcast::<PyTuple>() {
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
Ok(Self(seq.into()))
} else {
Err(err)