Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 16:49:27 +00:00)
Makes `decode` and `decode_batch` work on borrowed content. (#1251)

* Makes `decode` and `decode_batch` work on borrowed content.
* Make `decode_batch` work with borrowed content.
* Fix lint.
* Attempt to map it into Node.
* Second attempt.
* Step by step.
* One more step.
* Fix lint.
* Please ...
* Removing collect.
* Revert "Removing collect."

This reverts commit 2f7ec04dc84df3cc5488625a4fcb492fdc3545e2.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
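For callers of the Rust crate, the visible effect of the hunks below is that `Tokenizer::decode` now borrows its ids as `&[u32]` instead of consuming a `Vec<u32>`, and `decode_batch` changes analogously to take `&[&[u32]]`. A minimal sketch of a call site, assuming an already-built `tokenizer`; only the signature comes from this commit:

    use tokenizers::Tokenizer;

    // Sketch only: the tokenizer construction is left to the caller.
    fn print_decoded(tokenizer: &Tokenizer, ids: Vec<u32>) -> tokenizers::Result<()> {
        // Previously this call had to hand over an owned Vec<u32> (or clone it);
        // now a borrow suffices and `ids` stays usable afterwards.
        let text = tokenizer.decode(&ids, true)?;
        println!("{}", text);
        Ok(())
    }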
@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
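The batch arm keeps an explicit `collect` because a `Vec<Vec<u32>>` cannot be handed out directly as `&[&[u32]]`: `Vec<u32>` and `&[u32]` have different layouts, so each row has to be re-borrowed into a new vector of slices (the reverted "Removing collect." step in the commit message above reflects exactly this). A standalone sketch of that conversion, with made-up ids:

    fn main() {
        // One small Vec of slice references is allocated; the ids themselves are not copied.
        let batch: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
        let rows: Vec<&[u32]> = batch.iter().map(|v| v.as_slice()).collect();
        let borrowed: &[&[u32]] = &rows;
        assert_eq!(borrowed.len(), 2);
    }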
@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }
 
     /// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }
 
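`&v[..]` in this hunk is simply another spelling of `v.as_slice()` (the Node binding above uses the method form); both borrow a `Vec<u32>` as `&[u32]` without copying:

    fn main() {
        let v: Vec<u32> = vec![101, 2023, 102];
        let a: &[u32] = &v[..];
        let b: &[u32] = v.as_slice();
        assert_eq!(a, b);
    }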
@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
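`Encoding::get_ids()` already returns a borrowed `&[u32]`, so once `decode` accepts a slice the `.to_vec()` copy in the CLI (and in the doc tests further below) simply drops out. A hedged round-trip sketch, assuming a built `tokenizer`:

    use tokenizers::Tokenizer;

    fn roundtrip(tokenizer: &Tokenizer, text: &str) -> tokenizers::Result<String> {
        let encoding = tokenizer.encode(text, false)?;
        // get_ids() borrows from `encoding`; no intermediate Vec<u32> is built.
        tokenizer.decode(encoding.get_ids(), true)
    }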
@@ -795,12 +795,12 @@ where
     }
 
     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
             .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
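The switch from `.into_iter()` to `.iter()` is what makes the borrowed signature work: iterating a `&[u32]` yields `&u32`, so the id is dereferenced with `*id` where an owned `u32` is expected, which is free because `u32` is `Copy`. A minimal sketch of the same shape; `lookup` is a hypothetical stand-in for `id_to_token`:

    // Hypothetical stand-in for AddedVocabulary::id_to_token.
    fn lookup(id: u32) -> Option<String> {
        if id < 1000 { Some(format!("tok_{}", id)) } else { None }
    }

    fn tokens_for(ids: &[u32]) -> Vec<String> {
        // iter() over a slice yields &u32; *id copies the value out.
        ids.iter().filter_map(|id| lookup(*id)).collect()
    }

    fn main() {
        assert_eq!(tokens_for(&[1, 2, 2000]), vec!["tok_1", "tok_2"]);
    }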
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
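With `decode_batch` borrowing its input, a caller that already holds borrowed ids, for example from a batch of encodings, only allocates a small vector of slice references instead of cloning every row. A hedged sketch, assuming a built `tokenizer`:

    use tokenizers::Tokenizer;

    fn encode_then_decode(
        tokenizer: &Tokenizer,
        texts: Vec<&str>,
    ) -> tokenizers::Result<Vec<String>> {
        let encodings = tokenizer.encode_batch(texts, false)?;
        // Each get_ids() borrows from its Encoding; no ids are copied.
        let rows: Vec<&[u32]> = encodings.iter().map(|e| e.get_ids()).collect();
        tokenizer.decode_batch(&rows, true)
    }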
@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);
 
-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }
 
@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]
 
     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
 
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
 
     bert_tokenizer.with_decoder(WordPieceDecoder::default());
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");
|