mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Makes decode
and decode_batch
work on borrowed content. (#1251)
* Makes `decode` and `decode_batch` work on borrowed content. * Make `decode_batch` work with borrowed content. * Fix lint. * Attempt to map it into Node. * Second attempt. * Step by step. * One more step. * Fix lint. * Please ... * Removing collect. * Revert "Removing collect." This reverts commit 2f7ec04dc84df3cc5488625a4fcb492fdc3545e2. --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -106,14 +106,17 @@ impl Task for DecodeTask {
|
||||
.tokenizer
|
||||
.read()
|
||||
.unwrap()
|
||||
.decode(ids.to_vec(), *skip_special_tokens)
|
||||
.decode(ids.as_slice(), *skip_special_tokens)
|
||||
.map_err(|e| format!("{}", e))
|
||||
.map(DecodeOutput::Single),
|
||||
DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
|
||||
.tokenizer
|
||||
.read()
|
||||
.unwrap()
|
||||
.decode_batch(ids.to_vec(), *skip_special_tokens)
|
||||
.decode_batch(
|
||||
&ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
|
||||
*skip_special_tokens,
|
||||
)
|
||||
.map_err(|e| format!("{}", e))
|
||||
.map(DecodeOutput::Batch),
|
||||
}
|
||||
|
@ -1009,7 +1009,7 @@ impl PyTokenizer {
|
||||
#[pyo3(signature = (ids, skip_special_tokens = true))]
|
||||
#[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
|
||||
fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
|
||||
ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
|
||||
ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
|
||||
}
|
||||
|
||||
/// Decode a batch of ids back to their corresponding string
|
||||
@ -1032,7 +1032,8 @@ impl PyTokenizer {
|
||||
skip_special_tokens: bool,
|
||||
) -> PyResult<Vec<String>> {
|
||||
py.allow_threads(|| {
|
||||
ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
|
||||
let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
|
||||
ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
|
||||
println!("Offsets:\t{:?}", encoded.get_offsets());
|
||||
println!(
|
||||
"Decoded:\t{}",
|
||||
tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
|
||||
tokenizer.decode(encoded.get_ids(), true).unwrap()
|
||||
);
|
||||
println!("Tokenized in {:?}", elapsed);
|
||||
}
|
||||
|
@ -795,12 +795,12 @@ where
|
||||
}
|
||||
|
||||
/// Decode the given ids, back to a String
|
||||
pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
|
||||
pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
let tokens = ids
|
||||
.into_iter()
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
self.added_vocabulary
|
||||
.id_to_token(id, &self.model)
|
||||
.id_to_token(*id, &self.model)
|
||||
.filter(|token| {
|
||||
!skip_special_tokens || !self.added_vocabulary.is_special_token(token)
|
||||
})
|
||||
@ -1008,7 +1008,7 @@ where
|
||||
/// Decode all sentences in parallel
|
||||
pub fn decode_batch(
|
||||
&self,
|
||||
sentences: Vec<Vec<u32>>,
|
||||
sentences: &[&[u32]],
|
||||
skip_special_tokens: bool,
|
||||
) -> Result<Vec<String>>
|
||||
where
|
||||
|
@ -54,7 +54,7 @@ fn load_tokenizer() {
|
||||
assert_eq!(encodings.get_ids(), ids);
|
||||
assert_eq!(encodings.get_tokens(), tokens);
|
||||
|
||||
let decoded = tokenizer.decode(ids, false).unwrap();
|
||||
let decoded = tokenizer.decode(&ids, false).unwrap();
|
||||
assert_eq!(decoded, example);
|
||||
}
|
||||
|
||||
@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
|
||||
// [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]
|
||||
|
||||
let decoded = tokenizer.decode(
|
||||
vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
|
||||
&[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
|
||||
true,
|
||||
)?;
|
||||
println!("{}", decoded);
|
||||
@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
|
||||
println!("{:?}", output.get_tokens());
|
||||
// ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]
|
||||
|
||||
let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
|
||||
let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
|
||||
println!("{}", decoded);
|
||||
// "welcome to the tok ##eni ##zer ##s library ."
|
||||
// END bert_test_decoding
|
||||
@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
|
||||
use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;
|
||||
|
||||
bert_tokenizer.with_decoder(WordPieceDecoder::default());
|
||||
let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
|
||||
let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
|
||||
// "welcome to the tokenizers library."
|
||||
// END bert_proper_decoding
|
||||
assert_eq!(decoded, "welcome to the tokenizers library.");
|
||||
|
Reference in New Issue
Block a user