Python - Improve encode/encode_batch
@@ -88,66 +88,99 @@ impl PyObjectProtocol for AddedToken {
     }
 }
 
-#[pyclass]
-struct InputSequence {
-    sequence: tk::InputSequence,
-}
-
-impl FromPyObject<'_> for InputSequence {
+struct TextInputSequence(tk::InputSequence);
+impl FromPyObject<'_> for TextInputSequence {
     fn extract(ob: &PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
+        let err = exceptions::ValueError::py_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
             let seq: String = s.extract().map_err(|_| err)?;
-            Ok(Self {
-                sequence: seq.into(),
-            })
-        } else if let Ok(s) = ob.downcast::<PyList>() {
-            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
-            Ok(Self {
-                sequence: seq.into(),
-            })
+            Ok(Self(seq.into()))
         } else {
             Err(err)
         }
     }
 }
 
-impl From<InputSequence> for tk::InputSequence {
-    fn from(s: InputSequence) -> Self {
-        s.sequence
+impl From<TextInputSequence> for tk::InputSequence {
+    fn from(s: TextInputSequence) -> Self {
+        s.0
     }
 }
 
-#[pyclass]
-struct EncodeInput {
-    input: tk::EncodeInput,
-}
-
-impl FromPyObject<'_> for EncodeInput {
+struct PreTokenizedInputSequence(tk::InputSequence);
+impl FromPyObject<'_> for PreTokenizedInputSequence {
     fn extract(ob: &PyAny) -> PyResult<Self> {
         let err = exceptions::ValueError::py_err(
-            "EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
+            "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
         );
 
+        if let Ok(s) = ob.downcast::<PyList>() {
+            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
+            Ok(Self(seq.into()))
+        } else if let Ok(s) = ob.downcast::<PyTuple>() {
+            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
+            Ok(Self(seq.into()))
+        } else {
+            Err(err)
+        }
+    }
+}
+impl From<PreTokenizedInputSequence> for tk::InputSequence {
+    fn from(s: PreTokenizedInputSequence) -> Self {
+        s.0
+    }
+}
+
+struct TextEncodeInput(tk::EncodeInput);
+impl FromPyObject<'_> for TextEncodeInput {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err(
+            "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
+        );
+
         let gil = Python::acquire_gil();
         let py = gil.python();
         let obj = ob.to_object(py);
 
-        if let Ok(i) = obj.extract::<InputSequence>(py) {
-            Ok(Self { input: i.into() })
-        } else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
-            Ok(Self {
-                input: (i1, i2).into(),
-            })
+        if let Ok(i) = obj.extract::<TextInputSequence>(py) {
+            Ok(Self(i.into()))
+        } else if let Ok((i1, i2)) = obj.extract::<(TextInputSequence, TextInputSequence)>(py) {
+            Ok(Self((i1, i2).into()))
         } else {
             Err(err)
         }
     }
 }
+impl From<TextEncodeInput> for tk::tokenizer::EncodeInput {
+    fn from(i: TextEncodeInput) -> Self {
+        i.0
+    }
+}
+struct PreTokenizedEncodeInput(tk::EncodeInput);
+impl FromPyObject<'_> for PreTokenizedEncodeInput {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err(
+            "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
+             Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
+        );
 
-impl From<EncodeInput> for tk::tokenizer::EncodeInput {
-    fn from(i: EncodeInput) -> Self {
-        i.input
+        let gil = Python::acquire_gil();
+        let py = gil.python();
+        let obj = ob.to_object(py);
+
+        if let Ok(i) = obj.extract::<PreTokenizedInputSequence>(py) {
+            Ok(Self(i.into()))
+        } else if let Ok((i1, i2)) =
+            obj.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>(py)
+        {
+            Ok(Self((i1, i2).into()))
+        } else {
+            Err(err)
+        }
+    }
+}
+impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
+    fn from(i: PreTokenizedEncodeInput) -> Self {
+        i.0
     }
 }
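In short, this hunk splits the single polymorphic InputSequence/EncodeInput pair into four explicit wrappers. Read back as Python-side types, they accept the following shapes; this is a sketch inferred from the ValueError messages above, not code from the commit:

# TextInputSequence: a plain str.
text_seq = "A single sequence"

# PreTokenizedInputSequence: a List[str] or Tuple[str] of tokens.
pretok_seq = ["A", "pre", "tokenized", "sequence"]

# TextEncodeInput: a single text sequence, or a (sequence, pair) tuple.
text_input = ("A sequence", "And its pair")

# PreTokenizedEncodeInput: a pre-tokenized sequence, or a pair of them.
pretok_input = (["A", "pre", "tokenized"], ["And", "its", "pair"])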
@@ -295,11 +328,37 @@ impl Tokenizer {
 
     /// Input can be:
     /// encode("A single sequence")
-    /// encode(("A sequence", "And its pair"))
-    /// encode([ "A", "pre", "tokenized", "sequence" ])
-    /// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
-    #[args(add_special_tokens = true)]
-    fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
+    /// encode("A sequence", "And its pair")
+    /// encode([ "A", "pre", "tokenized", "sequence" ], is_pretokenized=True)
+    /// encode(
+    ///     [ "A", "pre", "tokenized", "sequence" ], [ "And", "its", "pair" ],
+    ///     is_pretokenized=True
+    /// )
+    #[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")]
+    fn encode(
+        &self,
+        sequence: &PyAny,
+        pair: Option<&PyAny>,
+        is_pretokenized: bool,
+        add_special_tokens: bool,
+    ) -> PyResult<Encoding> {
+        let sequence: tk::InputSequence = if is_pretokenized {
+            sequence.extract::<PreTokenizedInputSequence>()?.into()
+        } else {
+            sequence.extract::<TextInputSequence>()?.into()
+        };
+        let input = match pair {
+            Some(pair) => {
+                let pair: tk::InputSequence = if is_pretokenized {
+                    pair.extract::<PreTokenizedInputSequence>()?.into()
+                } else {
+                    pair.extract::<TextInputSequence>()?.into()
+                };
+                tk::EncodeInput::Dual(sequence, pair)
+            }
+            None => tk::EncodeInput::Single(sequence),
+        };
+
         ToPyResult(
             self.tokenizer
                 .encode(input, add_special_tokens)
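The resulting Python call patterns, as spelled out in the updated docstring (a sketch; the tokenizer object and its construction are assumed, not shown in the diff):

# Assumed: `tokenizer` is an already-constructed tokenizers.Tokenizer.
encoding = tokenizer.encode("A single sequence")
encoding = tokenizer.encode("A sequence", "And its pair")
encoding = tokenizer.encode(["A", "pre", "tokenized", "sequence"], is_pretokenized=True)
encoding = tokenizer.encode(
    ["A", "pre", "tokenized", "sequence"], ["And", "its", "pair"],
    is_pretokenized=True,
)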
@@ -315,12 +374,24 @@ impl Tokenizer {
     ///     [ "A", "pre", "tokenized", "sequence" ],
     ///     ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
     /// ])
-    #[args(add_special_tokens = true)]
+    #[args(is_pretokenized = "false", add_special_tokens = "true")]
     fn encode_batch(
         &self,
-        input: Vec<EncodeInput>,
+        input: Vec<&PyAny>,
+        is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<Encoding>> {
+        let input: Vec<tk::EncodeInput> = input
+            .into_iter()
+            .map(|o| {
+                let input: tk::EncodeInput = if is_pretokenized {
+                    o.extract::<PreTokenizedEncodeInput>()?.into()
+                } else {
+                    o.extract::<TextEncodeInput>()?.into()
+                };
+                Ok(input)
+            })
+            .collect::<PyResult<Vec<tk::EncodeInput>>>()?;
         ToPyResult(
             self.tokenizer
                 .encode_batch(input, add_special_tokens)
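Batch usage under the new signature would look like this (again a sketch with an assumed `tokenizer`; note that `is_pretokenized` applies to the whole batch, so text and pre-tokenized items cannot be mixed in one call):

# Text mode: each item is a str or a (str, str) pair.
encodings = tokenizer.encode_batch([
    "A single sequence",
    ("A sequence", "And its pair"),
])

# Pre-tokenized mode: each item is a List[str] or a pair of token lists.
encodings = tokenizer.encode_batch(
    [
        ["A", "pre", "tokenized", "sequence"],
        (["A", "pre", "tokenized", "sequence"], ["And", "its", "pair"]),
    ],
    is_pretokenized=True,
)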