Python - InputSequence with references when possible

This commit is contained in:
Anthony MOI
2020-08-19 17:16:49 -04:00
committed by Anthony MOI
parent 053cfd836a
commit d919d68889

View File

@@ -171,27 +171,26 @@ impl PyObjectProtocol for PyAddedToken {
} }
} }
// Side-by-side diff row (old | new): newtype wrapper around tk::InputSequence.
// The new version adds a lifetime parameter 's and passes it to tk::InputSequence.
struct TextInputSequence(tk::InputSequence); struct TextInputSequence<'s>(tk::InputSequence<'s>);
// Side-by-side diff rows (old | new): FromPyObject impl for TextInputSequence.
// Both versions accept only an object that downcasts to PyString; anything else
// yields the ValueError built below. The new version ties the result to the
// lifetime 's of the borrowed PyAny.
impl FromPyObject<'_> for TextInputSequence { impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
fn extract(ob: &PyAny) -> PyResult<Self> { fn extract(ob: &'s PyAny) -> PyResult<Self> {
// The error value is constructed up front in both versions, before the downcast is tried.
let err = exceptions::ValueError::py_err("TextInputSequence must be str"); let err = exceptions::ValueError::py_err("TextInputSequence must be str");
if let Ok(s) = ob.downcast::<PyString>() { if let Ok(s) = ob.downcast::<PyString>() {
// Old: extracts an owned String and replaces any extraction failure with `err`.
// New: calls s.to_string() and propagates its own error via `?`, so a failure here
// is no longer replaced by `err`.
// NOTE(review): presumably to_string() can borrow from the Python string (per the
// commit title) -- confirm against the PyO3 version in use.
let seq: String = s.extract().map_err(|_| err)?; Ok(Self(s.to_string()?.into()))
Ok(Self(seq.into()))
} else { } else {
Err(err) Err(err)
} }
} }
} }
// Side-by-side diff rows (old | new): unwraps the newtype by returning its inner
// tk::InputSequence (field 0). The new version only threads the lifetime 's through.
impl From<TextInputSequence> for tk::InputSequence { impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
fn from(s: TextInputSequence) -> Self { fn from(s: TextInputSequence<'s>) -> Self {
s.0 s.0
} }
} }
// Side-by-side diff row (old | new): newtype wrapper around tk::InputSequence for
// pre-tokenized input. The new version adds a lifetime parameter 's.
struct PreTokenizedInputSequence(tk::InputSequence); struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
// Side-by-side diff rows (old | new): FromPyObject impl for PreTokenizedInputSequence.
// Only the signature and the ValueError construction are visible here; the rest of
// the extraction body is elided by the hunk boundary (the `@@` row below), so its
// logic cannot be described from this view.
impl FromPyObject<'_> for PreTokenizedInputSequence { impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
fn extract(ob: &PyAny) -> PyResult<Self> { fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err( let err = exceptions::ValueError::py_err(
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]", "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
); );
@@ -207,53 +206,45 @@ impl FromPyObject<'_> for PreTokenizedInputSequence {
} }
} }
} }
// Side-by-side diff rows (old | new): unwraps the newtype by returning its inner
// tk::InputSequence (field 0). The new version only threads the lifetime 's through.
impl From<PreTokenizedInputSequence> for tk::InputSequence { impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
fn from(s: PreTokenizedInputSequence) -> Self { fn from(s: PreTokenizedInputSequence<'s>) -> Self {
s.0 s.0
} }
} }
// Side-by-side diff row (old | new): newtype wrapper around tk::EncodeInput.
// The new version adds a lifetime parameter 's and passes it to tk::EncodeInput.
struct TextEncodeInput(tk::EncodeInput); struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
// Side-by-side diff rows (old | new): FromPyObject impl for TextEncodeInput.
// Both versions first try to extract a single TextInputSequence, then a 2-tuple of
// TextInputSequence, and otherwise return the ValueError built below.
impl FromPyObject<'_> for TextEncodeInput { impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
fn extract(ob: &PyAny) -> PyResult<Self> { fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err( let err = exceptions::ValueError::py_err(
"TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]", "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
); );
// Old (left, next four rows): acquires the GIL, converts `ob` with to_object(py),
// and extracts from that object.
// New (right, first row only): extracts directly from the borrowed `ob`; the
// acquire_gil / to_object steps are gone. The rows are aligned by position, so the
// new `if let` sits beside the old `let gil`.
let gil = Python::acquire_gil(); if let Ok(i) = ob.extract::<TextInputSequence>() {
let py = gil.python();
let obj = ob.to_object(py);
if let Ok(i) = obj.extract::<TextInputSequence>(py) {
Ok(Self(i.into())) Ok(Self(i.into()))
} else if let Ok((i1, i2)) = obj.extract::<(TextInputSequence, TextInputSequence)>(py) { } else if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
Ok(Self((i1, i2).into())) Ok(Self((i1, i2).into()))
} else { } else {
Err(err) Err(err)
} }
} }
} }
// Side-by-side diff rows (old | new): unwraps the newtype by returning its inner
// EncodeInput (field 0). The new version only threads the lifetime 's through.
impl From<TextEncodeInput> for tk::tokenizer::EncodeInput { impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
fn from(i: TextEncodeInput) -> Self { fn from(i: TextEncodeInput<'s>) -> Self {
i.0 i.0
} }
} }
// Side-by-side diff row (old | new): newtype wrapper around tk::EncodeInput for
// pre-tokenized input. The new version adds a lifetime parameter 's.
struct PreTokenizedEncodeInput(tk::EncodeInput); struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
// Side-by-side diff rows (old | new): FromPyObject impl for PreTokenizedEncodeInput.
// Both versions first try a single PreTokenizedInputSequence, then a 2-tuple of them.
// NOTE(review): the body of the final `else` branch is cut off by the hunk boundary
// (the `@@` row below); it presumably returns `err` like the TextEncodeInput impl --
// confirm against the full file.
impl FromPyObject<'_> for PreTokenizedEncodeInput { impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
fn extract(ob: &PyAny) -> PyResult<Self> { fn extract(ob: &'s PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err( let err = exceptions::ValueError::py_err(
"PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \ "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]", Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
); );
// Old (left, next four rows): acquires the GIL, converts `ob` with to_object(py),
// and extracts from that object.
// New (right, first row only): extracts directly from the borrowed `ob`.
let gil = Python::acquire_gil(); if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
let py = gil.python();
let obj = ob.to_object(py);
if let Ok(i) = obj.extract::<PreTokenizedInputSequence>(py) {
Ok(Self(i.into())) Ok(Self(i.into()))
} else if let Ok((i1, i2)) = } else if let Ok((i1, i2)) =
obj.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>(py) ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
{ {
Ok(Self((i1, i2).into())) Ok(Self((i1, i2).into()))
} else { } else {
@@ -261,8 +252,8 @@ impl FromPyObject<'_> for PreTokenizedEncodeInput {
} }
} }
} }
// Side-by-side diff rows (old | new): unwraps the newtype by returning its inner
// EncodeInput (field 0). The new version only threads the lifetime 's through.
impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput { impl<'s> From<PreTokenizedEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
fn from(i: PreTokenizedEncodeInput) -> Self { fn from(i: PreTokenizedEncodeInput<'s>) -> Self {
i.0 i.0
} }
} }