Python - Update bindings for new encode

This commit is contained in:
Anthony MOI
2020-04-24 21:28:02 -04:00
parent 993c1c80a8
commit 835f08ab02

View File

@@ -88,6 +88,69 @@ impl PyObjectProtocol for AddedToken {
}
}
#[pyclass]
struct InputSequence {
sequence: tk::InputSequence,
}
impl FromPyObject<'_> for InputSequence {
fn extract(ob: &PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
if let Ok(s) = ob.downcast::<PyString>() {
let seq: String = s.extract().map_err(|_| err)?;
Ok(Self {
sequence: seq.into(),
})
} else if let Ok(s) = ob.downcast::<PyList>() {
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
Ok(Self {
sequence: seq.into(),
})
} else {
Err(err)
}
}
}
impl From<InputSequence> for tk::InputSequence {
fn from(s: InputSequence) -> Self {
s.sequence
}
}
#[pyclass]
struct EncodeInput {
input: tk::EncodeInput,
}
impl FromPyObject<'_> for EncodeInput {
fn extract(ob: &PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
"EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
);
let gil = Python::acquire_gil();
let py = gil.python();
let obj = ob.to_object(py);
if let Ok(i) = obj.extract::<InputSequence>(py) {
Ok(Self { input: i.into() })
} else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
Ok(Self {
input: (i1, i2).into(),
})
} else {
Err(err)
}
}
}
impl From<EncodeInput> for tk::tokenizer::EncodeInput {
fn from(i: EncodeInput) -> Self {
i.input
}
}
#[pyclass(dict)]
pub struct Tokenizer {
tokenizer: tk::tokenizer::Tokenizer,
@@ -230,52 +293,37 @@ impl Tokenizer {
.into()
}
/// Input can be:
/// encode("A single sequence")
/// encode(("A sequence", "And its pair"))
/// encode([ "A", "pre", "tokenized", "sequence" ])
/// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
#[args(add_special_tokens = true)]
fn encode(
&self,
sentence: &str,
pair: Option<&str>,
add_special_tokens: bool,
) -> PyResult<Encoding> {
fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
ToPyResult(
self.tokenizer
.encode(
if let Some(pair) = pair {
tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.to_owned())
} else {
tk::tokenizer::EncodeInput::Single(sentence.to_owned())
},
add_special_tokens,
)
.encode(input, add_special_tokens)
.map(Encoding::new),
)
.into()
}
/// Input can be:
/// encode_batch([
/// "A single sequence",
/// ("A tuple with a sequence", "And its pair"),
/// [ "A", "pre", "tokenized", "sequence" ],
/// ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
/// ])
#[args(add_special_tokens = true)]
fn encode_batch(
&self,
sentences: &PyList,
input: Vec<EncodeInput>,
add_special_tokens: bool,
) -> PyResult<Vec<Encoding>> {
let inputs = sentences
.into_iter()
.map(|item| {
if let Ok(s1) = item.extract::<String>() {
Ok(tk::tokenizer::EncodeInput::Single(s1))
} else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
} else {
Err(exceptions::Exception::py_err(
"Input must be a list[str] or list[(str, str)]",
))
}
})
.collect::<PyResult<Vec<_>>>()?;
ToPyResult(
self.tokenizer
.encode_batch(inputs, add_special_tokens)
.encode_batch(input, add_special_tokens)
.map(|encodings| encodings.into_iter().map(Encoding::new).collect()),
)
.into()