mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-06 20:58:22 +00:00
Python - Improve encode/encode_batch
This commit is contained in:
@@ -88,66 +88,99 @@ impl PyObjectProtocol for AddedToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyclass]
|
struct TextInputSequence(tk::InputSequence);
|
||||||
struct InputSequence {
|
impl FromPyObject<'_> for TextInputSequence {
|
||||||
sequence: tk::InputSequence,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromPyObject<'_> for InputSequence {
|
|
||||||
fn extract(ob: &PyAny) -> PyResult<Self> {
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
|
let err = exceptions::ValueError::py_err("TextInputSequence must be str");
|
||||||
if let Ok(s) = ob.downcast::<PyString>() {
|
if let Ok(s) = ob.downcast::<PyString>() {
|
||||||
let seq: String = s.extract().map_err(|_| err)?;
|
let seq: String = s.extract().map_err(|_| err)?;
|
||||||
Ok(Self {
|
Ok(Self(seq.into()))
|
||||||
sequence: seq.into(),
|
|
||||||
})
|
|
||||||
} else if let Ok(s) = ob.downcast::<PyList>() {
|
|
||||||
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
|
||||||
Ok(Self {
|
|
||||||
sequence: seq.into(),
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
Err(err)
|
Err(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl From<TextInputSequence> for tk::InputSequence {
|
||||||
impl From<InputSequence> for tk::InputSequence {
|
fn from(s: TextInputSequence) -> Self {
|
||||||
fn from(s: InputSequence) -> Self {
|
s.0
|
||||||
s.sequence
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyclass]
|
struct PreTokenizedInputSequence(tk::InputSequence);
|
||||||
struct EncodeInput {
|
impl FromPyObject<'_> for PreTokenizedInputSequence {
|
||||||
input: tk::EncodeInput,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl FromPyObject<'_> for EncodeInput {
|
|
||||||
fn extract(ob: &PyAny) -> PyResult<Self> {
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
let err = exceptions::ValueError::py_err(
|
let err = exceptions::ValueError::py_err(
|
||||||
"EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
|
"PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Ok(s) = ob.downcast::<PyList>() {
|
||||||
|
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
||||||
|
Ok(Self(seq.into()))
|
||||||
|
} else if let Ok(s) = ob.downcast::<PyTuple>() {
|
||||||
|
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
||||||
|
Ok(Self(seq.into()))
|
||||||
|
} else {
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<PreTokenizedInputSequence> for tk::InputSequence {
|
||||||
|
fn from(s: PreTokenizedInputSequence) -> Self {
|
||||||
|
s.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TextEncodeInput(tk::EncodeInput);
|
||||||
|
impl FromPyObject<'_> for TextEncodeInput {
|
||||||
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
|
let err = exceptions::ValueError::py_err(
|
||||||
|
"TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
|
||||||
);
|
);
|
||||||
|
|
||||||
let gil = Python::acquire_gil();
|
let gil = Python::acquire_gil();
|
||||||
let py = gil.python();
|
let py = gil.python();
|
||||||
let obj = ob.to_object(py);
|
let obj = ob.to_object(py);
|
||||||
|
|
||||||
if let Ok(i) = obj.extract::<InputSequence>(py) {
|
if let Ok(i) = obj.extract::<TextInputSequence>(py) {
|
||||||
Ok(Self { input: i.into() })
|
Ok(Self(i.into()))
|
||||||
} else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
|
} else if let Ok((i1, i2)) = obj.extract::<(TextInputSequence, TextInputSequence)>(py) {
|
||||||
Ok(Self {
|
Ok(Self((i1, i2).into()))
|
||||||
input: (i1, i2).into(),
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
Err(err)
|
Err(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
impl From<TextEncodeInput> for tk::tokenizer::EncodeInput {
|
||||||
|
fn from(i: TextEncodeInput) -> Self {
|
||||||
|
i.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
struct PreTokenizedEncodeInput(tk::EncodeInput);
|
||||||
|
impl FromPyObject<'_> for PreTokenizedEncodeInput {
|
||||||
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
|
let err = exceptions::ValueError::py_err(
|
||||||
|
"PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
|
||||||
|
Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
|
||||||
|
);
|
||||||
|
|
||||||
impl From<EncodeInput> for tk::tokenizer::EncodeInput {
|
let gil = Python::acquire_gil();
|
||||||
fn from(i: EncodeInput) -> Self {
|
let py = gil.python();
|
||||||
i.input
|
let obj = ob.to_object(py);
|
||||||
|
|
||||||
|
if let Ok(i) = obj.extract::<PreTokenizedInputSequence>(py) {
|
||||||
|
Ok(Self(i.into()))
|
||||||
|
} else if let Ok((i1, i2)) =
|
||||||
|
obj.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>(py)
|
||||||
|
{
|
||||||
|
Ok(Self((i1, i2).into()))
|
||||||
|
} else {
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl From<PreTokenizedEncodeInput> for tk::tokenizer::EncodeInput {
|
||||||
|
fn from(i: PreTokenizedEncodeInput) -> Self {
|
||||||
|
i.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -295,11 +328,37 @@ impl Tokenizer {
|
|||||||
|
|
||||||
/// Input can be:
|
/// Input can be:
|
||||||
/// encode("A single sequence")
|
/// encode("A single sequence")
|
||||||
/// encode(("A sequence", "And its pair"))
|
/// encode("A sequence", "And its pair")
|
||||||
/// encode([ "A", "pre", "tokenized", "sequence" ])
|
/// encode([ "A", "pre", "tokenized", "sequence" ], is_pretokenized=True)
|
||||||
/// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
|
/// encode(
|
||||||
#[args(add_special_tokens = true)]
|
/// [ "A", "pre", "tokenized", "sequence" ], [ "And", "its", "pair" ],
|
||||||
fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
|
/// is_pretokenized=True
|
||||||
|
/// )
|
||||||
|
#[args(pair = "None", is_pretokenized = "false", add_special_tokens = "true")]
|
||||||
|
fn encode(
|
||||||
|
&self,
|
||||||
|
sequence: &PyAny,
|
||||||
|
pair: Option<&PyAny>,
|
||||||
|
is_pretokenized: bool,
|
||||||
|
add_special_tokens: bool,
|
||||||
|
) -> PyResult<Encoding> {
|
||||||
|
let sequence: tk::InputSequence = if is_pretokenized {
|
||||||
|
sequence.extract::<PreTokenizedInputSequence>()?.into()
|
||||||
|
} else {
|
||||||
|
sequence.extract::<TextInputSequence>()?.into()
|
||||||
|
};
|
||||||
|
let input = match pair {
|
||||||
|
Some(pair) => {
|
||||||
|
let pair: tk::InputSequence = if is_pretokenized {
|
||||||
|
pair.extract::<PreTokenizedInputSequence>()?.into()
|
||||||
|
} else {
|
||||||
|
pair.extract::<TextInputSequence>()?.into()
|
||||||
|
};
|
||||||
|
tk::EncodeInput::Dual(sequence, pair)
|
||||||
|
}
|
||||||
|
None => tk::EncodeInput::Single(sequence),
|
||||||
|
};
|
||||||
|
|
||||||
ToPyResult(
|
ToPyResult(
|
||||||
self.tokenizer
|
self.tokenizer
|
||||||
.encode(input, add_special_tokens)
|
.encode(input, add_special_tokens)
|
||||||
@@ -315,12 +374,24 @@ impl Tokenizer {
|
|||||||
/// [ "A", "pre", "tokenized", "sequence" ],
|
/// [ "A", "pre", "tokenized", "sequence" ],
|
||||||
/// ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
|
/// ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
|
||||||
/// ])
|
/// ])
|
||||||
#[args(add_special_tokens = true)]
|
#[args(is_pretokenized = "false", add_special_tokens = "true")]
|
||||||
fn encode_batch(
|
fn encode_batch(
|
||||||
&self,
|
&self,
|
||||||
input: Vec<EncodeInput>,
|
input: Vec<&PyAny>,
|
||||||
|
is_pretokenized: bool,
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> PyResult<Vec<Encoding>> {
|
) -> PyResult<Vec<Encoding>> {
|
||||||
|
let input: Vec<tk::EncodeInput> = input
|
||||||
|
.into_iter()
|
||||||
|
.map(|o| {
|
||||||
|
let input: tk::EncodeInput = if is_pretokenized {
|
||||||
|
o.extract::<PreTokenizedEncodeInput>()?.into()
|
||||||
|
} else {
|
||||||
|
o.extract::<TextEncodeInput>()?.into()
|
||||||
|
};
|
||||||
|
Ok(input)
|
||||||
|
})
|
||||||
|
.collect::<PyResult<Vec<tk::EncodeInput>>>()?;
|
||||||
ToPyResult(
|
ToPyResult(
|
||||||
self.tokenizer
|
self.tokenizer
|
||||||
.encode_batch(input, add_special_tokens)
|
.encode_batch(input, add_special_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user