Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Fix encode_batch and encode_batch_fast to accept ndarrays again (#1679)
* Fix encode_batch and encode_batch_fast to accept ndarrays again
* Fix clippy

Co-authored-by: Dimitris Iliopoulos <diliopoulos@fb.com>
Committed by: GitHub
Parent: f0c48bd89a
Commit: ac34660e44
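For context, a minimal sketch of the behavior this commit restores: encode_batch accepting a numpy ndarray rather than only a plain Python list. The tokenizer file path and the example strings below are illustrative, not taken from the repository.

# Sketch only: assumes a local tokenizer.json and an installed numpy.
import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

batch = np.array(["hello world", "how are you"])   # ndarray, not a list
encodings = tokenizer.encode_batch(batch)          # accepted again after this fix
print([enc.tokens for enc in encodings])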
@@ -408,10 +408,10 @@ impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
         if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<TextInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<TextInputSequence>()?;
+                let first = arr[0].extract::<TextInputSequence>()?;
+                let second = arr[1].extract::<TextInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
@@ -435,10 +435,10 @@ impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
         {
             return Ok(Self((i1, i2).into()));
         }
-        if let Ok(arr) = ob.downcast::<PyList>() {
+        if let Ok(arr) = ob.extract::<Vec<Bound<PyAny>>>() {
             if arr.len() == 2 {
-                let first = arr.get_item(0)?.extract::<PreTokenizedInputSequence>()?;
-                let second = arr.get_item(1)?.extract::<PreTokenizedInputSequence>()?;
+                let first = arr[0].extract::<PreTokenizedInputSequence>()?;
+                let second = arr[1].extract::<PreTokenizedInputSequence>()?;
                 return Ok(Self((first, second).into()));
             }
         }
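In both extractors above, the pair handling moves from downcast::<PyList>(), which only succeeds for actual Python list objects, to extract::<Vec<Bound<PyAny>>>(), which also converts other sequence-like inputs such as tuples and numpy rows. A hedged Python-side sketch of what that enables; the tokenizer path and texts are illustrative:

# Sketch only: a text pair supplied as a 2-element row of an ndarray.
import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

pair_batch = np.array([["question one", "context one"],
                       ["question two", "context two"]])
# Each 2-element row is extracted as a (sequence A, sequence B) pair.
encodings = tokenizer.encode_batch(pair_batch)
print([enc.tokens for enc in encodings])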
@@ -1033,13 +1033,12 @@ impl PyTokenizer {
     fn encode_batch(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in &input {
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
@@ -1093,13 +1092,12 @@ impl PyTokenizer {
     fn encode_batch_fast(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PySequence>,
+        input: Vec<Bound<'_, PyAny>>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
-        for i in 0..input.len()? {
-            let item = input.get_item(i)?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len());
+        for item in &input {
             let item: tk::EncodeInput = if is_pretokenized {
                 item.extract::<PreTokenizedEncodeInput>()?.into()
             } else {
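Both batch entry points now take the inputs as an already-converted Vec<Bound<'_, PyAny>> and iterate it directly, instead of indexing into a Bound<'_, PySequence> handle; each element is then converted exactly as before. A hedged sketch of the Python-visible effect, with an illustrative tokenizer path:

# Sketch only: encode_batch_fast accepts the same input shapes as encode_batch.
import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical path

texts = ["hello world", "how are you"]
from_list = tokenizer.encode_batch_fast(texts)
from_array = tokenizer.encode_batch_fast(np.array(texts))
assert [e.ids for e in from_list] == [e.ids for e in from_array]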
@@ -153,8 +153,6 @@ class TestTokenizer:
         assert len(output) == 2
 
     def test_encode_formats(self, bert_files):
-        print("Broken by the change from std::usize::Max to usixeMax")
-        return 0
         with pytest.deprecated_call():
            tokenizer = BertWordPieceTokenizer(bert_files["vocab"])
 
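Deleting the stray print and the early return 0 re-enables test_encode_formats, so the encode-format checks run again. A minimal sketch of the pattern that test follows; the vocab path and strings are illustrative:

# Sketch only: the upstream test wraps the constructor in pytest.deprecated_call(),
# so a DeprecationWarning is expected there; the vocab path is hypothetical.
import pytest
from tokenizers import BertWordPieceTokenizer

def test_encode_formats_sketch():
    with pytest.deprecated_call():
        tokenizer = BertWordPieceTokenizer("bert-vocab.txt")  # hypothetical path
    single = tokenizer.encode("my name is john")
    pair = tokenizer.encode("my name is john", "pair")
    assert single.tokens and pair.tokens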
|