From efaa6f589ab062b0e4ac9f45c7f075f0782fa90d Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Mon, 27 Apr 2020 14:20:56 -0500
Subject: [PATCH] Python - Improve encode/encode_batch

---
 bindings/python/src/tokenizer.rs | 155 ++++++++++++++++++++++---------
 1 file changed, 113 insertions(+), 42 deletions(-)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 5ffcef1a..14ba6782 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -88,66 +88,99 @@ impl PyObjectProtocol for AddedToken {
     }
 }
 
-#[pyclass]
-struct InputSequence {
-    sequence: tk::InputSequence,
-}
-
-impl FromPyObject<'_> for InputSequence {
+struct TextInputSequence(tk::InputSequence);
+impl FromPyObject<'_> for TextInputSequence {
     fn extract(ob: &PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
+        let err = exceptions::ValueError::py_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
             let seq: String = s.extract().map_err(|_| err)?;
-            Ok(Self {
-                sequence: seq.into(),
-            })
-        } else if let Ok(s) = ob.downcast::<PyList>() {
-            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
-            Ok(Self {
-                sequence: seq.into(),
-            })
+            Ok(Self(seq.into()))
         } else {
             Err(err)
         }
     }
 }
-
-impl From<InputSequence> for tk::InputSequence {
-    fn from(s: InputSequence) -> Self {
-        s.sequence
+impl From<TextInputSequence> for tk::InputSequence {
+    fn from(s: TextInputSequence) -> Self {
+        s.0
     }
 }
 
-#[pyclass]
-struct EncodeInput {
-    input: tk::EncodeInput,
-}
-
-impl FromPyObject<'_> for EncodeInput {
+struct PreTokenizedInputSequence(tk::InputSequence);
+impl FromPyObject<'_> for PreTokenizedInputSequence {
     fn extract(ob: &PyAny) -> PyResult<Self> {
         let err = exceptions::ValueError::py_err(
-            "EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
+            "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
+        );
+
+        if let Ok(s) = ob.downcast::<PyList>() {
+            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
+            Ok(Self(seq.into()))
+        } else if let Ok(s) = ob.downcast::<PyTuple>() {
+            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
+            Ok(Self(seq.into()))
+        } else {
+            Err(err)
+        }
+    }
+}
+impl From<PreTokenizedInputSequence> for tk::InputSequence {
+    fn from(s: PreTokenizedInputSequence) -> Self {
+        s.0
+    }
+}
+
+struct TextEncodeInput(tk::EncodeInput);
+impl FromPyObject<'_> for TextEncodeInput {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err(
+            "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
         );
 
         let gil = Python::acquire_gil();
         let py = gil.python();
         let obj = ob.to_object(py);
 
-        if let Ok(i) = obj.extract::<InputSequence>(py) {
-            Ok(Self { input: i.into() })
-        } else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
-            Ok(Self {
-                input: (i1, i2).into(),
-            })
+        if let Ok(i) = obj.extract::<TextInputSequence>(py) {
+            Ok(Self(i.into()))
+        } else if let Ok((i1, i2)) = obj.extract::<(TextInputSequence, TextInputSequence)>(py) {
+            Ok(Self((i1, i2).into()))
         } else {
             Err(err)
         }
     }
 }
+impl From<TextEncodeInput> for tk::tokenizer::EncodeInput {
+    fn from(i: TextEncodeInput) -> Self {
+        i.0
+    }
+}
+struct PreTokenizedEncodeInput(tk::EncodeInput);
+impl FromPyObject<'_> for PreTokenizedEncodeInput {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err(
+            "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
+             Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
+        );
 
-impl From<EncodeInput> for tk::tokenizer::EncodeInput {
-    fn from(i: EncodeInput) -> Self {
-        i.input
+        let gil = Python::acquire_gil();
+        let py = gil.python();
+        let obj = ob.to_object(py);
+
+        if let Ok(i) = obj.extract::<PreTokenizedInputSequence>(py) {
+            Ok(Self(i.into()))
+        } else if let Ok((i1, i2)) =
+            obj.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>(py)
+        {
+            Ok(Self((i1, i2).into()))
+        } else {
+            Err(err)
+        }
+    }
+}
+impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
+    fn from(i: PreTokenizedEncodeInput) -> Self {
+        i.0
     }
 }
 
@@ -295,11 +328,37 @@ impl Tokenizer {
 
     /// Input can be:
    /// encode("A single sequence")
-    /// encode(("A sequence", "And its pair"))
-    /// encode([ "A", "pre", "tokenized", "sequence" ])
-    /// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
-    #[args(add_special_tokens = true)]
-    fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
+    /// encode("A sequence", "And its pair")
+    /// encode([ "A", "pre", "tokenized", "sequence" ], is_pretokenized=True)
+    /// encode(
+    ///     [ "A", "pre", "tokenized", "sequence" ], [ "And", "its", "pair" ],
+    ///     is_pretokenized=True
+    /// )
+    #[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")]
+    fn encode(
+        &self,
+        sequence: &PyAny,
+        pair: Option<&PyAny>,
+        is_pretokenized: bool,
+        add_special_tokens: bool,
+    ) -> PyResult<Encoding> {
+        let sequence: tk::InputSequence = if is_pretokenized {
+            sequence.extract::<PreTokenizedInputSequence>()?.into()
+        } else {
+            sequence.extract::<TextInputSequence>()?.into()
+        };
+        let input = match pair {
+            Some(pair) => {
+                let pair: tk::InputSequence = if is_pretokenized {
+                    pair.extract::<PreTokenizedInputSequence>()?.into()
+                } else {
+                    pair.extract::<TextInputSequence>()?.into()
+                };
+                tk::EncodeInput::Dual(sequence, pair)
+            }
+            None => tk::EncodeInput::Single(sequence),
+        };
+
         ToPyResult(
             self.tokenizer
                 .encode(input, add_special_tokens)
@@ -315,12 +374,24 @@ impl Tokenizer {
     ///     [ "A", "pre", "tokenized", "sequence" ],
     ///     ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
     /// ])
-    #[args(add_special_tokens = true)]
+    #[args(is_pretokenized = "false", add_special_tokens = "true")]
     fn encode_batch(
         &self,
-        input: Vec<EncodeInput>,
+        input: Vec<&PyAny>,
+        is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<Encoding>> {
+        let input: Vec<tk::EncodeInput> = input
+            .into_iter()
+            .map(|o| {
+                let input: tk::EncodeInput = if is_pretokenized {
+                    o.extract::<PreTokenizedEncodeInput>()?.into()
+                } else {
+                    o.extract::<TextEncodeInput>()?.into()
+                };
+                Ok(input)
+            })
+            .collect::<PyResult<Vec<tk::EncodeInput>>>()?;
         ToPyResult(
             self.tokenizer
                 .encode_batch(input, add_special_tokens)
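
Note for reviewers: the sketch below shows roughly how the new Python-side API behaves once this patch lands, following the docstrings above. It is illustrative only and not part of the patch; the BPE.from_files setup is just one assumed way to obtain a Tokenizer instance.

    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    # Assumed setup: any Tokenizer instance works the same way; building it
    # from vocab/merges files is not part of this patch.
    tokenizer = Tokenizer(BPE.from_files("vocab.json", "merges.txt"))

    # Raw text: a single sequence, or a sequence plus its pair as a second
    # positional argument (no more wrapping the pair in a tuple).
    encoding = tokenizer.encode("A single sequence")
    encoding = tokenizer.encode("A sequence", "And its pair")

    # Pre-tokenized input: a list (or tuple) of strings, flagged explicitly.
    encoding = tokenizer.encode(
        ["A", "pre", "tokenized", "sequence"], is_pretokenized=True
    )

    # Batches mix single sequences and (sequence, pair) tuples; the
    # is_pretokenized flag applies to every item in the batch.
    encodings = tokenizer.encode_batch(
        ["A single sequence", ("A sequence", "And its pair")],
        add_special_tokens=True,
    )

The old EncodeInput extraction had to guess the input shape per item, so a list of strings was ambiguous between a pre-tokenized sequence and a batch of raw texts; passing is_pretokenized explicitly removes that ambiguity.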