Python - Update bindings for new encode
@@ -88,6 +88,69 @@ impl PyObjectProtocol for AddedToken {
     }
 }
 
+#[pyclass]
+struct InputSequence {
+    sequence: tk::InputSequence,
+}
+
+impl FromPyObject<'_> for InputSequence {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
+        if let Ok(s) = ob.downcast::<PyString>() {
+            let seq: String = s.extract().map_err(|_| err)?;
+            Ok(Self {
+                sequence: seq.into(),
+            })
+        } else if let Ok(s) = ob.downcast::<PyList>() {
+            let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
+            Ok(Self {
+                sequence: seq.into(),
+            })
+        } else {
+            Err(err)
+        }
+    }
+}
+
+impl From<InputSequence> for tk::InputSequence {
+    fn from(s: InputSequence) -> Self {
+        s.sequence
+    }
+}
+
+#[pyclass]
+struct EncodeInput {
+    input: tk::EncodeInput,
+}
+
+impl FromPyObject<'_> for EncodeInput {
+    fn extract(ob: &PyAny) -> PyResult<Self> {
+        let err = exceptions::ValueError::py_err(
+            "EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
+        );
+
+        let gil = Python::acquire_gil();
+        let py = gil.python();
+        let obj = ob.to_object(py);
+
+        if let Ok(i) = obj.extract::<InputSequence>(py) {
+            Ok(Self { input: i.into() })
+        } else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
+            Ok(Self {
+                input: (i1, i2).into(),
+            })
+        } else {
+            Err(err)
+        }
+    }
+}
+
+impl From<EncodeInput> for tk::tokenizer::EncodeInput {
+    fn from(i: EncodeInput) -> Self {
+        i.input
+    }
+}
+
 #[pyclass(dict)]
 pub struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
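For orientation (this sketch is not part of the commit), the error messages above describe the Python values these FromPyObject impls are meant to accept: an InputSequence is either a str or a List[str], and an EncodeInput is either a single sequence or a Tuple of two sequences. The variable names below are made up purely for illustration.

# Illustrative Python values; all names here are hypothetical.
raw_sequence = "A single sequence"                     # str            -> InputSequence
pre_tokenized = ["A", "pre", "tokenized", "sequence"]  # List[str]      -> InputSequence

single_input = raw_sequence                            # one sequence   -> EncodeInput
paired_input = (raw_sequence, "And its pair")          # Tuple[InputSequence, InputSequence]
mixed_pair = (pre_tokenized, "And its pair")           # pre-tokenized sequence paired with a str

# Any other shape should fail extraction with a ValueError such as
# "InputSequence must be Union[str, List[str]]".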
@@ -230,52 +293,37 @@ impl Tokenizer {
         .into()
     }
 
+    /// Input can be:
+    /// encode("A single sequence")
+    /// encode(("A sequence", "And its pair"))
+    /// encode([ "A", "pre", "tokenized", "sequence" ])
+    /// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
     #[args(add_special_tokens = true)]
-    fn encode(
-        &self,
-        sentence: &str,
-        pair: Option<&str>,
-        add_special_tokens: bool,
-    ) -> PyResult<Encoding> {
+    fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
         ToPyResult(
             self.tokenizer
-                .encode(
-                    if let Some(pair) = pair {
-                        tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.to_owned())
-                    } else {
-                        tk::tokenizer::EncodeInput::Single(sentence.to_owned())
-                    },
-                    add_special_tokens,
-                )
+                .encode(input, add_special_tokens)
                 .map(Encoding::new),
         )
         .into()
     }
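Below is a usage sketch of the new encode signature, assuming tokenizer is an already constructed tokenizers.Tokenizer instance (its construction is outside this diff). The call forms mirror the doc comment added above, and add_special_tokens defaults to True per the #[args] attribute.

# Assumes `tokenizer` is an existing tokenizers.Tokenizer instance.
encoding = tokenizer.encode("A single sequence")
encoding = tokenizer.encode(("A sequence", "And its pair"))
encoding = tokenizer.encode(["A", "pre", "tokenized", "sequence"])
encoding = tokenizer.encode((["A", "pre", "tokenized", "sequence"], "And its pair"))

# Special-token handling can be disabled explicitly.
encoding = tokenizer.encode("A single sequence", add_special_tokens=False)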
+    /// Input can be:
+    /// encode_batch([
+    ///     "A single sequence",
+    ///     ("A tuple with a sequence", "And its pair"),
+    ///     [ "A", "pre", "tokenized", "sequence" ],
+    ///     ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
+    /// ])
     #[args(add_special_tokens = true)]
     fn encode_batch(
         &self,
-        sentences: &PyList,
+        input: Vec<EncodeInput>,
         add_special_tokens: bool,
     ) -> PyResult<Vec<Encoding>> {
-        let inputs = sentences
-            .into_iter()
-            .map(|item| {
-                if let Ok(s1) = item.extract::<String>() {
-                    Ok(tk::tokenizer::EncodeInput::Single(s1))
-                } else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
-                    Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
-                } else {
-                    Err(exceptions::Exception::py_err(
-                        "Input must be a list[str] or list[(str, str)]",
-                    ))
-                }
-            })
-            .collect::<PyResult<Vec<_>>>()?;
-
         ToPyResult(
             self.tokenizer
-                .encode_batch(inputs, add_special_tokens)
+                .encode_batch(input, add_special_tokens)
                 .map(|encodings| encodings.into_iter().map(Encoding::new).collect()),
         )
         .into()
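Similarly, a hedged sketch of the updated encode_batch call, again assuming an existing tokenizer instance: each list element is one EncodeInput, so plain sequences, pairs, and pre-tokenized sequences can be mixed within a single batch, and one Encoding is returned per element.

# Assumes `tokenizer` is an existing tokenizers.Tokenizer instance.
encodings = tokenizer.encode_batch([
    "A single sequence",
    ("A tuple with a sequence", "And its pair"),
    ["A", "pre", "tokenized", "sequence"],
    (["A", "pre", "tokenized", "sequence"], "And its pair"),
], add_special_tokens=True)

assert len(encodings) == 4  # one Encoding per batch entry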