Update PyO3 (#426)

This commit is contained in:
Anthony MOI
2020-09-22 12:00:20 -04:00
committed by GitHub
parent 8e220dbdd4
commit 940f8bd8fa
13 changed files with 156 additions and 178 deletions

View File

@ -30,20 +30,16 @@ impl PyPostProcessor {
let base = self.clone();
let gil = Python::acquire_gil();
let py = gil.python();
match self.processor.as_ref() {
PostProcessorWrapper::ByteLevel(_) => {
Py::new(py, (PyByteLevel {}, base)).map(Into::into)
}
PostProcessorWrapper::Bert(_) => {
Py::new(py, (PyBertProcessing {}, base)).map(Into::into)
}
Ok(match self.processor.as_ref() {
PostProcessorWrapper::ByteLevel(_) => Py::new(py, (PyByteLevel {}, base))?.into_py(py),
PostProcessorWrapper::Bert(_) => Py::new(py, (PyBertProcessing {}, base))?.into_py(py),
PostProcessorWrapper::Roberta(_) => {
Py::new(py, (PyRobertaProcessing {}, base)).map(Into::into)
Py::new(py, (PyRobertaProcessing {}, base))?.into_py(py)
}
PostProcessorWrapper::Template(_) => {
Py::new(py, (PyTemplateProcessing {}, base)).map(Into::into)
Py::new(py, (PyTemplateProcessing {}, base))?.into_py(py)
}
}
})
}
}
@ -67,7 +63,7 @@ impl PostProcessor for PyPostProcessor {
impl PyPostProcessor {
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let data = serde_json::to_string(self.processor.as_ref()).map_err(|e| {
exceptions::Exception::py_err(format!(
exceptions::PyException::new_err(format!(
"Error while attempting to pickle PostProcessor: {}",
e.to_string()
))
@ -79,7 +75,7 @@ impl PyPostProcessor {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.processor = serde_json::from_slice(s.as_bytes()).map_err(|e| {
exceptions::Exception::py_err(format!(
exceptions::PyException::new_err(format!(
"Error while attempting to unpickle PostProcessor: {}",
e.to_string()
))
@ -181,11 +177,11 @@ impl FromPyObject<'_> for PySpecialToken {
} else if let Ok(d) = ob.downcast::<PyDict>() {
let id = d
.get_item("id")
.ok_or_else(|| exceptions::ValueError::py_err("`id` must be specified"))?
.ok_or_else(|| exceptions::PyValueError::new_err("`id` must be specified"))?
.extract::<String>()?;
let ids = d
.get_item("ids")
.ok_or_else(|| exceptions::ValueError::py_err("`ids` must be specified"))?
.ok_or_else(|| exceptions::PyValueError::new_err("`ids` must be specified"))?
.extract::<Vec<u32>>()?;
let type_ids = d.get_item("type_ids").map_or_else(
|| Ok(vec![None; ids.len()]),
@ -193,14 +189,14 @@ impl FromPyObject<'_> for PySpecialToken {
)?;
let tokens = d
.get_item("tokens")
.ok_or_else(|| exceptions::ValueError::py_err("`tokens` must be specified"))?
.ok_or_else(|| exceptions::PyValueError::new_err("`tokens` must be specified"))?
.extract::<Vec<String>>()?;
Ok(Self(
ToPyResult(SpecialToken::new(id, ids, type_ids, tokens)).into_py()?,
))
} else {
Err(exceptions::TypeError::py_err(
Err(exceptions::PyTypeError::new_err(
"Expected Union[Tuple[str, int], Tuple[int, str], dict]",
))
}
@ -223,7 +219,7 @@ impl FromPyObject<'_> for PyTemplate {
} else if let Ok(s) = ob.extract::<Vec<&str>>() {
Ok(Self(s.into()))
} else {
Err(exceptions::TypeError::py_err(
Err(exceptions::PyTypeError::new_err(
"Expected Union[str, List[str]]",
))
}
@ -252,7 +248,7 @@ impl PyTemplateProcessing {
if let Some(sp) = special_tokens {
builder.special_tokens(sp);
}
let processor = builder.build().map_err(exceptions::ValueError::py_err)?;
let processor = builder.build().map_err(exceptions::PyValueError::new_err)?;
Ok((
PyTemplateProcessing {},
@ -265,7 +261,7 @@ impl PyTemplateProcessing {
mod test {
use std::sync::Arc;
use pyo3::{AsPyRef, Python};
use pyo3::prelude::*;
use tk::processors::bert::BertProcessing;
use tk::processors::PostProcessorWrapper;