Python - Improve encode/encode_batch

This commit is contained in:
Anthony MOI
2020-04-27 14:20:56 -05:00
parent 36e3c28a23
commit efaa6f589a

View File

@@ -88,66 +88,99 @@ impl PyObjectProtocol for AddedToken {
}
}
#[pyclass]
struct InputSequence {
sequence: tk::InputSequence,
}
impl FromPyObject<'_> for InputSequence {
struct TextInputSequence(tk::InputSequence);
impl FromPyObject<'_> for TextInputSequence {
fn extract(ob: &PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
let err = exceptions::ValueError::py_err("TextInputSequence must be str");
if let Ok(s) = ob.downcast::<PyString>() {
let seq: String = s.extract().map_err(|_| err)?;
Ok(Self {
sequence: seq.into(),
})
} else if let Ok(s) = ob.downcast::<PyList>() {
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
Ok(Self {
sequence: seq.into(),
})
Ok(Self(seq.into()))
} else {
Err(err)
}
}
}
impl From<InputSequence> for tk::InputSequence {
fn from(s: InputSequence) -> Self {
s.sequence
impl From<TextInputSequence> for tk::InputSequence {
fn from(s: TextInputSequence) -> Self {
s.0
}
}
#[pyclass]
struct EncodeInput {
input: tk::EncodeInput,
}
impl FromPyObject<'_> for EncodeInput {
struct PreTokenizedInputSequence(tk::InputSequence);
impl FromPyObject<'_> for PreTokenizedInputSequence {
fn extract(ob: &PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
"EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
);
if let Ok(s) = ob.downcast::<PyList>() {
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
Ok(Self(seq.into()))
} else if let Ok(s) = ob.downcast::<PyTuple>() {
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
Ok(Self(seq.into()))
} else {
Err(err)
}
}
}
impl From<PreTokenizedInputSequence> for tk::InputSequence {
fn from(s: PreTokenizedInputSequence) -> Self {
s.0
}
}
struct TextEncodeInput(tk::EncodeInput);
impl FromPyObject<'_> for TextEncodeInput {
fn extract(ob: &PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
"TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
);
let gil = Python::acquire_gil();
let py = gil.python();
let obj = ob.to_object(py);
if let Ok(i) = obj.extract::<InputSequence>(py) {
Ok(Self { input: i.into() })
} else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
Ok(Self {
input: (i1, i2).into(),
})
if let Ok(i) = obj.extract::<TextInputSequence>(py) {
Ok(Self(i.into()))
} else if let Ok((i1, i2)) = obj.extract::<(TextInputSequence, TextInputSequence)>(py) {
Ok(Self((i1, i2).into()))
} else {
Err(err)
}
}
}
impl From<TextEncodeInput> for tk::tokenizer::EncodeInput {
fn from(i: TextEncodeInput) -> Self {
i.0
}
}
struct PreTokenizedEncodeInput(tk::EncodeInput);
impl FromPyObject<'_> for PreTokenizedEncodeInput {
fn extract(ob: &PyAny) -> PyResult<Self> {
let err = exceptions::ValueError::py_err(
"PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
);
impl From<EncodeInput> for tk::tokenizer::EncodeInput {
fn from(i: EncodeInput) -> Self {
i.input
let gil = Python::acquire_gil();
let py = gil.python();
let obj = ob.to_object(py);
if let Ok(i) = obj.extract::<PreTokenizedInputSequence>(py) {
Ok(Self(i.into()))
} else if let Ok((i1, i2)) =
obj.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>(py)
{
Ok(Self((i1, i2).into()))
} else {
Err(err)
}
}
}
impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
fn from(i: PreTokenizedEncodeInput) -> Self {
i.0
}
}
@@ -295,11 +328,37 @@ impl Tokenizer {
/// Input can be:
/// encode("A single sequence")
/// encode(("A sequence", "And its pair"))
/// encode([ "A", "pre", "tokenized", "sequence" ])
/// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
#[args(add_special_tokens = true)]
fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
/// encode("A sequence", "And its pair")
/// encode([ "A", "pre", "tokenized", "sequence" ], is_pretokenized=True)
/// encode(
/// [ "A", "pre", "tokenized", "sequence" ], [ "And", "its", "pair" ],
/// is_pretokenized=True
/// )
#[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")]
fn encode(
&self,
sequence: &PyAny,
pair: Option<&PyAny>,
is_pretokenized: bool,
add_special_tokens: bool,
) -> PyResult<Encoding> {
let sequence: tk::InputSequence = if is_pretokenized {
sequence.extract::<PreTokenizedInputSequence>()?.into()
} else {
sequence.extract::<TextInputSequence>()?.into()
};
let input = match pair {
Some(pair) => {
let pair: tk::InputSequence = if is_pretokenized {
pair.extract::<PreTokenizedInputSequence>()?.into()
} else {
pair.extract::<TextInputSequence>()?.into()
};
tk::EncodeInput::Dual(sequence, pair)
}
None => tk::EncodeInput::Single(sequence),
};
ToPyResult(
self.tokenizer
.encode(input, add_special_tokens)
@@ -315,12 +374,24 @@ impl Tokenizer {
/// [ "A", "pre", "tokenized", "sequence" ],
/// ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
/// ])
#[args(add_special_tokens = true)]
#[args(is_pretokenized = "false", add_special_tokens = "true")]
fn encode_batch(
&self,
input: Vec<EncodeInput>,
input: Vec<&PyAny>,
is_pretokenized: bool,
add_special_tokens: bool,
) -> PyResult<Vec<Encoding>> {
let input: Vec<tk::EncodeInput> = input
.into_iter()
.map(|o| {
let input: tk::EncodeInput = if is_pretokenized {
o.extract::<PreTokenizedEncodeInput>()?.into()
} else {
o.extract::<TextEncodeInput>()?.into()
};
Ok(input)
})
.collect::<PyResult<Vec<tk::EncodeInput>>>()?;
ToPyResult(
self.tokenizer
.encode_batch(input, add_special_tokens)