Python - Improve and Test EncodeInput extraction

This commit is contained in:
Anthony MOI
2020-08-21 17:41:34 -04:00
committed by Anthony MOI
parent 220e68117d
commit 3d1322f108
2 changed files with 135 additions and 46 deletions

View File

@ -177,7 +177,7 @@ impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err("TextInputSequence must be str");
if let Ok(s) = ob.downcast::<PyString>() {
Ok(Self(s.to_string()?.into()))
Ok(Self(s.to_string().map_err(|_| err)?.into()))
} else {
Err(err)
}
@ -276,23 +276,25 @@ impl From<PyArrayStr> for tk::InputSequence<'_> {
struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
);
if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
Ok(Self(seq.into()))
} else if let Ok(seq) = ob.extract::<PyArrayStr>() {
Ok(Self(seq.into()))
} else if let Ok(s) = ob.downcast::<PyList>() {
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
Ok(Self(seq.into()))
} else if let Ok(s) = ob.downcast::<PyTuple>() {
let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
Ok(Self(seq.into()))
} else {
Err(err)
return Ok(Self(seq.into()));
}
if let Ok(seq) = ob.extract::<PyArrayStr>() {
return Ok(Self(seq.into()));
}
if let Ok(s) = ob.downcast::<PyList>() {
if let Ok(seq) = s.extract::<Vec<&str>>() {
return Ok(Self(seq.into()));
}
}
if let Ok(s) = ob.downcast::<PyTuple>() {
if let Ok(seq) = s.extract::<Vec<&str>>() {
return Ok(Self(seq.into()));
}
}
Err(exceptions::ValueError::py_err(
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
))
}
}
impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
@ -304,17 +306,22 @@ impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
"TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
);
if let Ok(i) = ob.extract::<TextInputSequence>() {
Ok(Self(i.into()))
} else if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
Ok(Self((i1, i2).into()))
} else {
Err(err)
return Ok(Self(i.into()));
}
if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
return Ok(Self((i1, i2).into()));
}
if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
if arr.len() == 2 {
let first = arr[0].extract::<TextInputSequence>()?;
let second = arr[1].extract::<TextInputSequence>()?;
return Ok(Self((first, second).into()));
}
}
Err(exceptions::ValueError::py_err(
"TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
))
}
}
impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
@ -325,20 +332,24 @@ impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
return Ok(Self(i.into()));
}
if let Ok((i1, i2)) = ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
{
return Ok(Self((i1, i2).into()));
}
if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
if arr.len() == 2 {
let first = arr[0].extract::<PreTokenizedInputSequence>()?;
let second = arr[1].extract::<PreTokenizedInputSequence>()?;
return Ok(Self((first, second).into()));
}
}
Err(exceptions::ValueError::py_err(
"PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
);
if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
Ok(Self(i.into()))
} else if let Ok((i1, i2)) =
ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
{
Ok(Self((i1, i2).into()))
} else {
Err(err)
}
))
}
}
impl<'s> From<PreTokenizedEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {