mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 20:28:22 +00:00
Python - Update bindings for new encode
This commit is contained in:
@@ -88,6 +88,69 @@ impl PyObjectProtocol for AddedToken {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
struct InputSequence {
|
||||||
|
sequence: tk::InputSequence,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromPyObject<'_> for InputSequence {
|
||||||
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
|
let err = exceptions::ValueError::py_err("InputSequence must be Union[str, List[str]]");
|
||||||
|
if let Ok(s) = ob.downcast::<PyString>() {
|
||||||
|
let seq: String = s.extract().map_err(|_| err)?;
|
||||||
|
Ok(Self {
|
||||||
|
sequence: seq.into(),
|
||||||
|
})
|
||||||
|
} else if let Ok(s) = ob.downcast::<PyList>() {
|
||||||
|
let seq = s.extract::<Vec<String>>().map_err(|_| err)?;
|
||||||
|
Ok(Self {
|
||||||
|
sequence: seq.into(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<InputSequence> for tk::InputSequence {
|
||||||
|
fn from(s: InputSequence) -> Self {
|
||||||
|
s.sequence
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
struct EncodeInput {
|
||||||
|
input: tk::EncodeInput,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FromPyObject<'_> for EncodeInput {
|
||||||
|
fn extract(ob: &PyAny) -> PyResult<Self> {
|
||||||
|
let err = exceptions::ValueError::py_err(
|
||||||
|
"EncodeInput must be Union[InputSequence, Tuple[InputSequence, InputSequence]]",
|
||||||
|
);
|
||||||
|
|
||||||
|
let gil = Python::acquire_gil();
|
||||||
|
let py = gil.python();
|
||||||
|
let obj = ob.to_object(py);
|
||||||
|
|
||||||
|
if let Ok(i) = obj.extract::<InputSequence>(py) {
|
||||||
|
Ok(Self { input: i.into() })
|
||||||
|
} else if let Ok((i1, i2)) = obj.extract::<(InputSequence, InputSequence)>(py) {
|
||||||
|
Ok(Self {
|
||||||
|
input: (i1, i2).into(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<EncodeInput> for tk::tokenizer::EncodeInput {
|
||||||
|
fn from(i: EncodeInput) -> Self {
|
||||||
|
i.input
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[pyclass(dict)]
|
#[pyclass(dict)]
|
||||||
pub struct Tokenizer {
|
pub struct Tokenizer {
|
||||||
tokenizer: tk::tokenizer::Tokenizer,
|
tokenizer: tk::tokenizer::Tokenizer,
|
||||||
@@ -230,52 +293,37 @@ impl Tokenizer {
|
|||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Input can be:
|
||||||
|
/// encode("A single sequence")
|
||||||
|
/// encode(("A sequence", "And its pair"))
|
||||||
|
/// encode([ "A", "pre", "tokenized", "sequence" ])
|
||||||
|
/// encode(([ "A", "pre", "tokenized", "sequence" ], "And its pair"))
|
||||||
#[args(add_special_tokens = true)]
|
#[args(add_special_tokens = true)]
|
||||||
fn encode(
|
fn encode(&self, input: EncodeInput, add_special_tokens: bool) -> PyResult<Encoding> {
|
||||||
&self,
|
|
||||||
sentence: &str,
|
|
||||||
pair: Option<&str>,
|
|
||||||
add_special_tokens: bool,
|
|
||||||
) -> PyResult<Encoding> {
|
|
||||||
ToPyResult(
|
ToPyResult(
|
||||||
self.tokenizer
|
self.tokenizer
|
||||||
.encode(
|
.encode(input, add_special_tokens)
|
||||||
if let Some(pair) = pair {
|
|
||||||
tk::tokenizer::EncodeInput::Dual(sentence.to_owned(), pair.to_owned())
|
|
||||||
} else {
|
|
||||||
tk::tokenizer::EncodeInput::Single(sentence.to_owned())
|
|
||||||
},
|
|
||||||
add_special_tokens,
|
|
||||||
)
|
|
||||||
.map(Encoding::new),
|
.map(Encoding::new),
|
||||||
)
|
)
|
||||||
.into()
|
.into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Input can be:
|
||||||
|
/// encode_batch([
|
||||||
|
/// "A single sequence",
|
||||||
|
/// ("A tuple with a sequence", "And its pair"),
|
||||||
|
/// [ "A", "pre", "tokenized", "sequence" ],
|
||||||
|
/// ([ "A", "pre", "tokenized", "sequence" ], "And its pair")
|
||||||
|
/// ])
|
||||||
#[args(add_special_tokens = true)]
|
#[args(add_special_tokens = true)]
|
||||||
fn encode_batch(
|
fn encode_batch(
|
||||||
&self,
|
&self,
|
||||||
sentences: &PyList,
|
input: Vec<EncodeInput>,
|
||||||
add_special_tokens: bool,
|
add_special_tokens: bool,
|
||||||
) -> PyResult<Vec<Encoding>> {
|
) -> PyResult<Vec<Encoding>> {
|
||||||
let inputs = sentences
|
|
||||||
.into_iter()
|
|
||||||
.map(|item| {
|
|
||||||
if let Ok(s1) = item.extract::<String>() {
|
|
||||||
Ok(tk::tokenizer::EncodeInput::Single(s1))
|
|
||||||
} else if let Ok((s1, s2)) = item.extract::<(String, String)>() {
|
|
||||||
Ok(tk::tokenizer::EncodeInput::Dual(s1, s2))
|
|
||||||
} else {
|
|
||||||
Err(exceptions::Exception::py_err(
|
|
||||||
"Input must be a list[str] or list[(str, str)]",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<PyResult<Vec<_>>>()?;
|
|
||||||
|
|
||||||
ToPyResult(
|
ToPyResult(
|
||||||
self.tokenizer
|
self.tokenizer
|
||||||
.encode_batch(inputs, add_special_tokens)
|
.encode_batch(input, add_special_tokens)
|
||||||
.map(|encodings| encodings.into_iter().map(Encoding::new).collect()),
|
.map(|encodings| encodings.into_iter().map(Encoding::new).collect()),
|
||||||
)
|
)
|
||||||
.into()
|
.into()
|
||||||
|
|||||||
Reference in New Issue
Block a user