diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index b7363261..5ffcef1a 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -88,6 +88,69 @@ impl PyObjectProtocol for AddedToken {
     }
 }
 
+#[pyclass]
+struct InputSequence {
+    sequence: tk::InputSequence,
+}
+
+impl FromPyObject<'_> for InputSequence {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
+        if let Ok(s) = ob.downcast::<PyString>() {
+            let seq: String = s.extract().map_err(|_| err)?;
+            Ok(Self {
+                sequence: seq.into(),
+            })
+        } else if let Ok(s) = ob.downcast::<PyList>() {
+            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
+            Ok(Self {
+                sequence: seq.into(),
+            })
+        } else {
+            Err(err)
+        }
+    }
+}
+
+impl From<InputSequence> for tk::InputSequence {
+    fn from(s: InputSequence) -> Self {
+        s.sequence
+    }
+}
+
+#[pyclass]
+struct EncodeInput {
+    input: tk::EncodeInput,
+}
+
+impl FromPyObject<'_> for EncodeInput {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err(
+            "EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
+        );
+
+        let gil = Python::acquire_gil();
+        let py = gil.python();
+        let obj = ob.to_object(py);
+
+        if let Ok(i) = obj.extract::<InputSequence>(py) {
+            Ok(Self { input: i.into() })
+        } else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
+            Ok(Self {
+                input: (i1, i2).into(),
+            })
+        } else {
+            Err(err)
+        }
+    }
+}
+
+impl From<EncodeInput> for tk::tokenizer::EncodeInput {
+    fn from(i: EncodeInput) -> Self {
+        i.input
+    }
+}
+
 #[pyclass(dict)]
 pub struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
@@ -230,52 +293,37 @@ impl Tokenizer {
         .into()
     }
 
+    /// Input can be:
+    ///     encode("A single sequence")
+    ///     encode(("A sequence", "And its pair"))
+    ///     encode([ "A", "pre", "tokenized", "sequence" ])
+    ///     encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
     #[args(add_special_tokens = true)]
-    fn encode(
-        &self,
-        sentence: &str,
-        pair: Option<&str>,
-        add_special_tokens: bool,
-    ) -> PyResult<Encoding> {
+    fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
         ToPyResult(
             self.tokenizer
-                .encode(
-                    if let Some(pair) = pair {
-                        tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.to_owned())
-                    } else {
-                        tk::tokenizer::EncodeInput::Single(sentence.to_owned())
-                    },
-                    add_special_tokens,
-                )
+                .encode(input, add_special_tokens)
                 .map(Encoding::new),
         )
         .into()
     }
 
+    /// Input can be:
+    ///     encode_batch([
+    ///         "A single sequence",
+    ///         ("A tuple with a sequence", "And its pair"),
+    ///         [ "A", "pre", "tokenized", "sequence" ],
+    ///         ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
+    ///     ])
     #[args(add_special_tokens = true)]
     fn encode_batch(
         &self,
-        sentences: &PyList,
+        input: Vec<EncodeInput>,
        add_special_tokens: bool,
     ) -> PyResult<Vec<Encoding>> {
-        let inputs = sentences
-            .into_iter()
-            .map(|item| {
-                if let Ok(s1) = item.extract::<String>() {
-                    Ok(tk::tokenizer::EncodeInput::Single(s1))
-                } else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
-                    Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
-                } else {
-                    Err(exceptions::Exception::py_err(
-                        "Input must be a list[str] or list[(str, str)]",
-                    ))
-                }
-            })
-            .collect::<PyResult<Vec<_>>>()?;
-
         ToPyResult(
             self.tokenizer
-                .encode_batch(inputs, add_special_tokens)
+                .encode_batch(input, add_special_tokens)
                 .map(|encodings| encodings.into_iter().map(Encoding::new).collect()),
         )
         .into()
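
For context, here is a minimal sketch of how the new signatures would be driven from Python, mirroring the doc comments in the diff above. The `Tokenizer(models.BPE())` construction is only a placeholder assumption so the snippet stands alone; any already-configured tokenizer instance would do.

```python
from tokenizers import Tokenizer, models

# Placeholder construction (assumption): substitute your own trained/configured tokenizer.
tokenizer = Tokenizer(models.BPE())

# A single raw sequence
tokenizer.encode("A single sequence")

# A sequence with its pair, passed as a tuple
tokenizer.encode(("A sequence", "And its pair"))

# A pre-tokenized sequence (list of str)
tokenizer.encode(["A", "pre", "tokenized", "sequence"])

# A pre-tokenized sequence paired with a raw string
tokenizer.encode((["A", "pre", "tokenized", "sequence"], "And its pair"))

# Batches can freely mix all of the above forms
tokenizer.encode_batch([
    "A single sequence",
    ("A tuple with a sequence", "And its pair"),
    ["A", "pre", "tokenized", "sequence"],
    (["A", "pre", "tokenized", "sequence"], "And its pair"),
])
```

Note that inputs which are not `Union[str, List[str]]` (or a pair of those) are now rejected with a `ValueError` raised from the `FromPyObject` implementations, instead of the generic `Exception` previously raised inside `encode_batch`.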