fix pylist (#1673)

* fix pylist

* add comment about why we use PySequence

* style

* fix encode batch fast as well

* Update bindings/python/src/tokenizer.rs

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* fix with capacity

* stub :)

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Commit 5e223ceb48 (parent 0f3a3f957e)
Author: Arthur
Committed by GitHub on 2024-11-05 16:24:23 +01:00
2 changed files with 30 additions and 28 deletions
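
The practical effect of switching the binding's argument type from `PyList` to `PySequence` is that `encode_batch` no longer insists on a Python `list`: any object implementing the sequence protocol (for example a `tuple`) is accepted. A minimal sketch of the new behavior, not part of the commit itself, assuming a tokenizer fetched from the Hugging Face Hub (any local tokenizer.json loaded with `Tokenizer.from_file` would work the same way):

    from tokenizers import Tokenizer

    # Assumed model name; any saved tokenizer behaves the same way.
    tokenizer = Tokenizer.from_pretrained("bert-base-uncased")

    # Before this change the PyList-typed argument required a list; with
    # PySequence, a tuple (or any other sequence) is accepted as well.
    encodings = tokenizer.encode_batch(("Hello there!", "How are you?"))
    print([e.tokens for e in encodings])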

Changed file 1 of 2 — Python stub for Tokenizer (encode_batch docstring)

@@ -859,7 +859,9 @@ class Tokenizer:
     def encode_batch(self, input, is_pretokenized=False, add_special_tokens=True):
         """
         Encode the given batch of inputs. This method accept both raw text sequences
-        as well as already pre-tokenized sequences.
+        as well as already pre-tokenized sequences. The reason we use `PySequence` is
+        because it allows type checking with zero-cost (according to PyO3) as we don't
+        have to convert to check.

         Example:
             Here are some examples of the inputs that are accepted::
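
For reference, the input shapes that docstring points at look roughly like this. This is a sketch rather than the verbatim docstring example; the pair and pre-tokenized forms mirror the `TextEncodeInput` and `PreTokenizedEncodeInput` branches in the Rust hunks below, and the model name is an assumption:

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # assumed model name

    # Raw text: one string per entry, or (text, pair) tuples
    tokenizer.encode_batch(["Hello there!", "How are you?"])
    tokenizer.encode_batch([("Hello there!", "How are you?")])

    # Already pre-tokenized: each entry is itself a sequence of words
    tokenizer.encode_batch(
        [["Hello", "there", "!"], ["How", "are", "you", "?"]],
        is_pretokenized=True,
    )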

Changed file 2 of 2 — bindings/python/src/tokenizer.rs

@@ -995,7 +995,9 @@ impl PyTokenizer {
     }

     /// Encode the given batch of inputs. This method accept both raw text sequences
-    /// as well as already pre-tokenized sequences.
+    /// as well as already pre-tokenized sequences. The reason we use `PySequence` is
+    /// because it allows type checking with zero-cost (according to PyO3) as we don't
+    /// have to convert to check.
     ///
     /// Example:
     ///     Here are some examples of the inputs that are accepted::
@@ -1030,25 +1032,24 @@ impl PyTokenizer {
     fn encode_batch(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PyList>,
+        input: Bound<'_, PySequence>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let input: Vec<tk::EncodeInput> = input
-            .into_iter()
-            .map(|o| {
-                let input: tk::EncodeInput = if is_pretokenized {
-                    o.extract::<PreTokenizedEncodeInput>()?.into()
-                } else {
-                    o.extract::<TextEncodeInput>()?.into()
-                };
-                Ok(input)
-            })
-            .collect::<PyResult<Vec<tk::EncodeInput>>>()?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
+        for i in 0..input.len()? {
+            let item = input.get_item(i)?;
+            let item: tk::EncodeInput = if is_pretokenized {
+                item.extract::<PreTokenizedEncodeInput>()?.into()
+            } else {
+                item.extract::<TextEncodeInput>()?.into()
+            };
+            items.push(item);
+        }
         py.allow_threads(|| {
             ToPyResult(
                 self.tokenizer
-                    .encode_batch_char_offsets(input, add_special_tokens)
+                    .encode_batch_char_offsets(items, add_special_tokens)
                     .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()),
             )
             .into()
@@ -1091,25 +1092,24 @@ impl PyTokenizer {
     fn encode_batch_fast(
         &self,
         py: Python<'_>,
-        input: Bound<'_, PyList>,
+        input: Bound<'_, PySequence>,
         is_pretokenized: bool,
         add_special_tokens: bool,
     ) -> PyResult<Vec<PyEncoding>> {
-        let input: Vec<tk::EncodeInput> = input
-            .into_iter()
-            .map(|o| {
-                let input: tk::EncodeInput = if is_pretokenized {
-                    o.extract::<PreTokenizedEncodeInput>()?.into()
-                } else {
-                    o.extract::<TextEncodeInput>()?.into()
-                };
-                Ok(input)
-            })
-            .collect::<PyResult<Vec<tk::EncodeInput>>>()?;
+        let mut items = Vec::<tk::EncodeInput>::with_capacity(input.len()?);
+        for i in 0..input.len()? {
+            let item = input.get_item(i)?;
+            let item: tk::EncodeInput = if is_pretokenized {
+                item.extract::<PreTokenizedEncodeInput>()?.into()
+            } else {
+                item.extract::<TextEncodeInput>()?.into()
+            };
+            items.push(item);
+        }
         py.allow_threads(|| {
             ToPyResult(
                 self.tokenizer
-                    .encode_batch_fast(input, add_special_tokens)
+                    .encode_batch_fast(items, add_special_tokens)
                    .map(|encodings| encodings.into_iter().map(|e| e.into()).collect()),
             )
             .into()
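
Since `encode_batch_fast` now goes through the same `PySequence`-based path, it too accepts any sequence type. A usage sketch, not part of the commit; the model name is an assumption, and unlike `encode_batch` this variant skips the character-offset conversion done via `encode_batch_char_offsets` above:

    from tokenizers import Tokenizer

    tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # assumed model name

    # Tuples (or any other sequence) work here too after this change.
    fast_encodings = tokenizer.encode_batch_fast(
        ("Hello there!", "How are you?"),
        add_special_tokens=True,
    )
    print([e.ids for e in fast_encodings])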