Makes decode and decode_batch work on borrowed content. (#1251)

* Makes `decode` and `decode_batch` work on borrowed content.

* Make `decode_batch` work with borrowed content.

* Fix lint.

* Attempt to map it into Node.

* Second attempt.

* Step by step.

* One more step.

* Fix lint.

* Please ...

* Removing collect.

* Revert "Removing collect."

This reverts commit 2f7ec04dc84df3cc5488625a4fcb492fdc3545e2.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Funtowicz Morgan
2023-05-17 11:18:15 +02:00
committed by GitHub
parent cefc41e8ec
commit b4fcc9ce6e
5 changed files with 17 additions and 13 deletions
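
For context, here is a minimal sketch (not part of the diff) of how Rust callers use the new borrowed-slice signatures; the "tokenizer.json" path and the input string are illustrative assumptions only:

use tokenizers::{Result, Tokenizer};

fn main() -> Result<()> {
    // Load any serialized tokenizer; the file name here is an assumption.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let encoding = tokenizer.encode("Hello, world!", false)?;

    // `decode` now takes `&[u32]`, so the ids can be passed as a borrowed slice
    // instead of the previous owned `Vec<u32>` (no more `.to_vec()` copy).
    let text = tokenizer.decode(encoding.get_ids(), true)?;
    println!("{}", text);

    // `decode_batch` now takes `&[&[u32]]`: a borrowed slice of borrowed id slices.
    let batch: Vec<&[u32]> = vec![encoding.get_ids(), encoding.get_ids()];
    let texts = tokenizer.decode_batch(&batch, true)?;
    println!("{:?}", texts);
    Ok(())
}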

@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
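
The intermediate collect in the Node binding above (see the "Removing collect" commit and its revert in the message) is needed because `decode_batch` expects a contiguous `&[&[u32]]`, which cannot be produced directly from a lazy iterator. A sketch of the idiom, with hypothetical names:

// Owned ids as stored by the binding must be re-viewed as slices before the call.
fn as_batch_input(ids: &[Vec<u32>]) -> Vec<&[u32]> {
    ids.iter().map(|v| v.as_slice()).collect()
}

fn demo() {
    let owned: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
    let views = as_batch_input(&owned);
    // `&views` coerces to `&[&[u32]]`, the type `decode_batch` now takes.
    assert_eq!(views.len(), 2);
}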

@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }
 
     /// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }

@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }

@@ -795,12 +795,12 @@ where
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
             .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
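
The remaining hunks below update call sites to match these signatures; for code that already holds owned ids, the migration is a borrow rather than a copy. A sketch, assuming an already-constructed `Tokenizer`:

fn migrate(tokenizer: &tokenizers::Tokenizer) -> tokenizers::Result<()> {
    // Existing owned ids keep working: `&Vec<u32>` deref-coerces to `&[u32]`,
    // so the only change at the call site is borrowing instead of moving.
    let owned_ids: Vec<u32> = vec![101, 7592, 102]; // illustrative ids
    let decoded = tokenizer.decode(&owned_ids, true)?;
    println!("{}", decoded);
    Ok(())
}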

@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);
 
-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }
@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]
     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
     bert_tokenizer.with_decoder(WordPieceDecoder::default());
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");