Update PyO3 (#426)

Anthony MOI
2020-09-22 12:00:20 -04:00
committed by GitHub
parent 8e220dbdd4
commit 940f8bd8fa
13 changed files with 156 additions and 178 deletions
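
This bumps the bindings to PyO3 0.12's reworked error API: exception types gain a `Py` prefix (`ValueError` -> `PyValueError`), errors are built with `new_err` instead of `py_err`, the fallible `PyString::to_string()` becomes the infallible `to_string_lossy()`, several overly generic `Exception`s are narrowed to `PyTypeError`/`PyValueError`, and the crate's `into_pyerr` helper now takes the target exception type as a generic parameter.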


@@ -175,9 +175,9 @@ impl PyObjectProtocol for PyAddedToken {
 struct TextInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err("TextInputSequence must be str");
+        let err = exceptions::PyTypeError::new_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
-            Ok(Self(s.to_string().map_err(|_| err)?.into()))
+            Ok(Self(s.to_string_lossy().into()))
         } else {
             Err(err)
         }
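
For reference, the post-upgrade extractor pattern looks roughly like this under the 0.12-era API this PR moves to; `Text` is a hypothetical stand-in for the crate's input-sequence types:

use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::PyString;

struct Text(String);

impl<'s> FromPyObject<'s> for Text {
    fn extract(ob: &'s PyAny) -> PyResult<Self> {
        if let Ok(s) = ob.downcast::<PyString>() {
            // to_string_lossy() cannot fail (invalid UTF-8 is replaced),
            // so the old `map_err` on the fallible `to_string()` goes away.
            Ok(Text(s.to_string_lossy().into_owned()))
        } else {
            // Exception types are now Py-prefixed and constructed with new_err.
            Err(PyTypeError::new_err("Text must be str"))
        }
    }
}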
@@ -207,7 +207,9 @@ impl FromPyObject<'_> for PyArrayUnicode {
         // type_num == 19 => Unicode
         if type_num != 19 {
-            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
+            return Err(exceptions::PyTypeError::new_err(
+                "Expected a np.array[dtype='U']",
+            ));
         }
         unsafe {
@@ -224,7 +226,7 @@ impl FromPyObject<'_> for PyArrayUnicode {
                 let py = gil.python();
                 let obj = PyObject::from_owned_ptr(py, unicode);
                 let s = obj.cast_as::<PyString>(py)?;
-                Ok(s.to_string()?.trim_matches(char::from(0)).to_owned())
+                Ok(s.to_string_lossy().trim_matches(char::from(0)).to_owned())
             })
             .collect::<PyResult<Vec<_>>>()?;
@@ -247,7 +249,9 @@ impl FromPyObject<'_> for PyArrayStr {
         let n_elem = array.shape()[0];
         if type_num != 17 {
-            return Err(exceptions::TypeError::py_err("Expected a np.array[str]"));
+            return Err(exceptions::PyTypeError::new_err(
+                "Expected a np.array[dtype='O']",
+            ));
         }
         unsafe {
@@ -259,7 +263,7 @@ impl FromPyObject<'_> for PyArrayStr {
                 let gil = Python::acquire_gil();
                 let py = gil.python();
                 let s = obj.cast_as::<PyString>(py)?;
-                Ok(s.to_string()?.into_owned())
+                Ok(s.to_string_lossy().into_owned())
             })
             .collect::<PyResult<Vec<_>>>()?;
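
The two `type_num` checks above compare against NumPy's C-level dtype enum, which is why the new messages name the dtype letter instead of the vague `np.array[str]`:

// Values from NumPy's ndarraytypes.h NPY_TYPES enum:
const NPY_OBJECT: i32 = 17; // dtype='O', arrays of Python objects (PyArrayStr)
const NPY_UNICODE: i32 = 19; // dtype='U', fixed-width unicode (PyArrayUnicode)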
@@ -292,7 +296,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
                 return Ok(Self(seq.into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
             "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
         ))
     }
@@ -319,7 +323,7 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
                 return Ok(Self((first, second).into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
             "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
         ))
     }
@@ -346,7 +350,7 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
                 return Ok(Self((first, second).into()));
             }
         }
-        Err(exceptions::ValueError::py_err(
+        Err(exceptions::PyTypeError::new_err(
             "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
             Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
         ))
@@ -385,9 +389,9 @@ impl PyTokenizer {
     fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
         let data = serde_json::to_string(&self.tokenizer).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyException::new_err(format!(
                 "Error while attempting to pickle Tokenizer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
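
A minimal sketch of the pickle plumbing this hunk touches, assuming a serde-serializable inner type; `Inner` and `Wrapper` are illustrative names, not the crate's:

use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct Inner {
    vocab_size: usize,
}

#[pyclass]
struct Wrapper {
    inner: Inner,
}

#[pymethods]
impl Wrapper {
    fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
        let data = serde_json::to_string(&self.inner)
            // `{}` already goes through Display, hence dropping `e.to_string()`.
            .map_err(|e| PyException::new_err(format!("Error while pickling: {}", e)))?;
        Ok(PyBytes::new(py, data.as_bytes()).to_object(py))
    }
}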
@@ -397,9 +401,9 @@ impl PyTokenizer {
         match state.extract::<&PyBytes>(py) {
             Ok(s) => {
                 self.tokenizer = serde_json::from_slice(s.as_bytes()).map_err(|e| {
-                    exceptions::Exception::py_err(format!(
+                    exceptions::PyException::new_err(format!(
                         "Error while attempting to unpickle Tokenizer: {}",
-                        e.to_string()
+                        e
                     ))
                 })?;
                 Ok(())
@@ -429,9 +433,9 @@ impl PyTokenizer {
     #[staticmethod]
     fn from_buffer(buffer: &PyBytes) -> PyResult<Self> {
         let tokenizer = serde_json::from_slice(buffer.as_bytes()).map_err(|e| {
-            exceptions::Exception::py_err(format!(
+            exceptions::PyValueError::new_err(format!(
                 "Cannot instantiate Tokenizer from buffer: {}",
-                e.to_string()
+                e
             ))
         })?;
         Ok(Self { tokenizer })
@@ -485,7 +489,7 @@ impl PyTokenizer {
                     one of `longest_first`, `only_first`, or `only_second`",
                     value
                 ))
-                .into_pyerr()),
+                .into_pyerr::<exceptions::PyValueError>()),
             }?
         }
         _ => println!("Ignored unknown kwarg option {}", key),
@@ -533,7 +537,7 @@ impl PyTokenizer {
                     one of `left` or `right`",
                     other
                 ))
-                .into_pyerr()),
+                .into_pyerr::<exceptions::PyValueError>()),
             }?;
         }
         "pad_to_multiple_of" => {
@@ -716,7 +720,7 @@ impl PyTokenizer {
                     token.is_special_token = false;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -736,7 +740,7 @@ impl PyTokenizer {
                     token.is_special_token = true;
                     Ok(token.get_token())
                 } else {
-                    Err(exceptions::Exception::py_err(
+                    Err(exceptions::PyTypeError::new_err(
                         "Input must be a List[Union[str, AddedToken]]",
                     ))
                 }
@@ -747,10 +751,7 @@ impl PyTokenizer {
     }
 
     fn train(&mut self, trainer: &PyTrainer, files: Vec<String>) -> PyResult<()> {
-        self.tokenizer
-            .train_and_replace(trainer, files)
-            .map_err(|e| exceptions::Exception::py_err(format!("{}", e)))?;
-        Ok(())
+        ToPyResult(self.tokenizer.train_and_replace(trainer, files)).into()
     }
 
     #[args(pair = "None", add_special_tokens = true)]
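
`ToPyResult(...).into()` folds the old map_err-plus-`Ok(())` dance into a single conversion. A minimal sketch of that newtype pattern, with a boxed error standing in for the library's error type:

use pyo3::exceptions::PyException;
use pyo3::PyResult;

struct ToPyResult<T>(Result<T, Box<dyn std::error::Error>>);

// `From<ToPyResult<T>> for PyResult<T>` would trip the orphan rule
// (PyResult is a foreign Result), so the conversion is written as
// Into on the local newtype instead.
impl<T> Into<PyResult<T>> for ToPyResult<T> {
    fn into(self) -> PyResult<T> {
        self.0.map_err(|e| PyException::new_err(format!("{}", e)))
    }
}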