Makes decode and decode_batch work on borrowed content. (#1251)

* Makes `decode` and `decode_batch` work on borrowed content.

* Make `decode_batch` work with borrowed content.

* Fix lint.

* Attempt to map it into Node.

* Second attempt.

* Step by step.

* One more step.

* Fix lint.

* Please ...

* Removing collect.

* Revert "Removing collect."

This reverts commit 2f7ec04dc84df3cc5488625a4fcb492fdc3545e2.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Author:    Funtowicz Morgan
Date:      2023-05-17 11:18:15 +02:00
Committed: GitHub
Parent:    cefc41e8ec
Commit:    b4fcc9ce6e
5 changed files with 17 additions and 13 deletions
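Taken together, the diffs below change the core signatures from owned to borrowed: `decode` now takes `&[u32]` instead of `Vec<u32>`, and `decode_batch` takes `&[&[u32]]` instead of `Vec<Vec<u32>>`, so call sites no longer clone id buffers with `.to_vec()`. A minimal caller-side sketch (the `tokenizer.json` path is a placeholder; any serialized tokenizer file works):

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Placeholder path: point this at any serialized tokenizer file.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let encoding = tokenizer.encode("Hello, world!", false)?;

    // `get_ids()` returns `&[u32]`, which the new signature accepts as-is;
    // previously the caller had to copy the buffer with `.to_vec()`.
    let text = tokenizer.decode(encoding.get_ids(), true)?;
    println!("{}", text);
    Ok(())
}
```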

@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
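The batch arm pays for the borrowed API with one small allocation: the task owns a `Vec<Vec<u32>>`, while `decode_batch` now wants `&[&[u32]]`, so each inner vector is borrowed as a slice and the slice headers are collected (the "Removing collect." commit above tried to drop this step and was reverted). A sketch of that adapter pattern in isolation; the helper name is illustrative, not part of the bindings:

```rust
// Borrow each owned `Vec<u32>` as a `&[u32]`. This allocates one vector of
// slice headers but never copies the id buffers themselves, which is what
// the previous `ids.to_vec()` clone did.
fn as_batch(ids: &[Vec<u32>]) -> Vec<&[u32]> {
    ids.iter().map(|v| v.as_slice()).collect()
}
```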

@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }

     /// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }
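`&v[..]` here is simply another spelling of the `v.as_slice()` used in the Node binding above; both borrow the vector's buffer without copying, as this standalone snippet illustrates:

```rust
fn main() {
    let v: Vec<u32> = vec![1, 2, 3];
    let a: &[u32] = &v[..];       // full-range index borrow
    let b: &[u32] = v.as_slice(); // explicit slice borrow
    // Both views point at the same backing buffer.
    assert_eq!(a.as_ptr(), b.as_ptr());
}
```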

@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }

@@ -795,12 +795,12 @@ where
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
             .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
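With `sentences: &[&[u32]]`, a batch is any slice of borrowed id slices, so callers that already hold encodings can decode without building an owned `Vec<Vec<u32>>`. A minimal sketch, again with a placeholder tokenizer path:

```rust
use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Placeholder path: any serialized tokenizer file works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let a = tokenizer.encode("first sentence", false)?;
    let b = tokenizer.encode("second sentence", false)?;

    // An array of `&[u32]` coerces to the expected `&[&[u32]]`, so the ids
    // stay inside the encodings; nothing is cloned for the call.
    let texts = tokenizer.decode_batch(&[a.get_ids(), b.get_ids()], true)?;
    println!("{:?}", texts);
    Ok(())
}
```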

@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);

-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }
@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]

     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]

-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
     bert_tokenizer.with_decoder(WordPieceDecoder::default());

-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");