Fix encode_batch and encode_batch_fast to accept ndarrays again (#1679)

* Fix encode_batch and encode_batch_fast to accept ndarrays again

* Fix clippy

---------

Co-authored-by: Dimitris Iliopoulos <diliopoulos@fb.com>
Dimitris Iliopoulos
2024-11-21 05:55:11 -05:00
committed by GitHub
parent f0c48bd89a
commit ac34660e44
2 changed files with 12 additions and 16 deletions
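
After this change, the batch encoding entry points accept NumPy arrays of strings again, alongside plain lists and tuples. A minimal usage sketch, assuming the tokenizers and numpy packages and a serialized tokenizer file ("tokenizer.json" below is a placeholder path):

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # placeholder path

# Plain lists keep working...
encodings = tokenizer.encode_batch(["Hello there", "General Kenobi"])

# ...and an ndarray of strings is accepted again after this commit.
batch = np.array(["Hello there", "General Kenobi"])
encodings = tokenizer.encode_batch(batch)
print([e.tokens for e in encodings])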

View File

@@ -408,10 +408,10 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
         if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<TextInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<TextInputSequence>()?;
+                let first = arr[0].extract::<TextInputSequence>()?;
+                let second = arr[1].extract::<TextInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
@@ -435,10 +435,10 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
         {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<PreTokenizedInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<PreTokenizedInputSequence>()?;
+                let first = arr[0].extract::<PreTokenizedInputSequence>()?;
+                let second = arr[1].extract::<PreTokenizedInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
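
Both FromPyObject impls above make the same substitution: ob.extract::<Vec<Bound<PyAny>>>() succeeds for any Python sequence, while the previous ob.downcast::<PyList>() only matched an actual list object, which is what rejected ndarray inputs on this branch. In Python terms, this is the code path that interprets a two-element sequence as a (sequence, pair) input; a hedged sketch reusing the placeholder path from above:

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # placeholder path

# A two-element sequence is treated as a sentence pair; after this
# change a two-element ndarray goes through the same branch as a list.
pair = np.array(["How are you?", "I am fine, thanks."])
encoding = tokenizer.encode_batch([pair])[0]
print(encoding.tokens)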
@@ -1033,13 +1033,12 @@ impl PyTokenizer {
     fn encode_batch(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in &input {
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
@@ -1093,13 +1092,12 @@ impl PyTokenizer {
     fn encode_batch_fast(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in &input {
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
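
encode_batch_fast receives the identical signature change, so both entry points now convert each element of the extracted Vec the same way, with is_pretokenized selecting the PreTokenizedEncodeInput path. A sketch of pretokenized ndarray input, under the same placeholder-path assumption as above:

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # placeholder path

# Each row is one pre-split sequence; is_pretokenized=True skips
# the tokenizer's own pre-tokenization step.
words = np.array([["Hello", "there"], ["General", "Kenobi"]])
encodings = tokenizer.encode_batch_fast(words, is_pretokenized=True)
print(len(encodings))  # 2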

View File

@@ -153,8 +153,6 @@ class TestTokenizer:
         assert len(output) == 2

     def test_encode_formats(self, bert_files):
-        print("Broken by the change from std::usize::Max to usixeMax")
-        return 0
         with pytest.deprecated_call():
             tokenizer = BertWordPieceTokenizer(bert_files["vocab"])