From b4fcc9ce6e4ad5806e82826f816acfdfdc4fcc67 Mon Sep 17 00:00:00 2001
From: Funtowicz Morgan
Date: Wed, 17 May 2023 11:18:15 +0200
Subject: [PATCH] Makes `decode` and `decode_batch` work on borrowed content.
 (#1251)

* Makes `decode` and `decode_batch` work on borrowed content.

* Make `decode_batch` work with borrowed content.

* Fix lint.

* Attempt to map it into Node.

* Second attempt.

* Step by step.

* One more step.

* Fix lint.

* Please ...

* Removing collect.

* Revert "Removing collect."

This reverts commit 2f7ec04dc84df3cc5488625a4fcb492fdc3545e2.

---------

Co-authored-by: Nicolas Patry
---
 bindings/node/native/src/tasks/tokenizer.rs | 7 +++++--
 bindings/python/src/tokenizer.rs            | 5 +++--
 tokenizers/src/cli.rs                       | 2 +-
 tokenizers/src/tokenizer/mod.rs             | 8 ++++----
 tokenizers/tests/documentation.rs           | 8 ++++----
 5 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/bindings/node/native/src/tasks/tokenizer.rs b/bindings/node/native/src/tasks/tokenizer.rs
index 2a7e3e0a..495ae53a 100644
--- a/bindings/node/native/src/tasks/tokenizer.rs
+++ b/bindings/node/native/src/tasks/tokenizer.rs
@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<_>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 95a954a2..1fe296ed 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }
 
     /// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }
 
diff --git a/tokenizers/src/cli.rs b/tokenizers/src/cli.rs
index 6bf523ef..54b82357 100644
--- a/tokenizers/src/cli.rs
+++ b/tokenizers/src/cli.rs
@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index a88306f3..01ec187c 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -795,12 +795,12 @@ where
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
            .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
diff --git a/tokenizers/tests/documentation.rs b/tokenizers/tests/documentation.rs
index 605f8a4b..7cf04deb 100644
--- a/tokenizers/tests/documentation.rs
+++ b/tokenizers/tests/documentation.rs
@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);
 
-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }
 
@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]
 
     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
 
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
     bert_tokenizer.with_decoder(WordPieceDecoder::default());
 
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");
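
Usage note (a minimal sketch, not part of the patch): after this change, `Tokenizer::decode` borrows a `&[u32]` and `Tokenizer::decode_batch` borrows a `&[&[u32]]`, so callers pass id slices instead of moving vectors. The snippet below assumes a tokenizer serialized to a hypothetical `tokenizer.json`; any loaded `Tokenizer` is used the same way.

    use tokenizers::Tokenizer;

    fn main() -> tokenizers::Result<()> {
        // Hypothetical file name; load any serialized tokenizer here.
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;
        let encoding = tokenizer.encode("Hello, world!", false)?;

        // `decode` now takes a borrowed slice of ids; no `.to_vec()` needed.
        let single = tokenizer.decode(encoding.get_ids(), true)?;
        println!("{}", single);

        // `decode_batch` now takes a slice of borrowed id slices.
        let first: Vec<u32> = encoding.get_ids().to_vec();
        let second: Vec<u32> = encoding.get_ids().to_vec();
        let batch = tokenizer.decode_batch(&[first.as_slice(), second.as_slice()], true)?;
        println!("{:?}", batch);

        Ok(())
    }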