Option to skip special tokens while decoding
@@ -192,12 +192,16 @@ impl Tokenizer {
         .into()
     }

-    fn decode(&self, ids: Vec<u32>) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids)).into()
+    fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
+        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
     }

-    fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> PyResult<Vec<String>> {
-        ToPyResult(self.tokenizer.decode_batch(sentences)).into()
+    fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> PyResult<Vec<String>> {
+        ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
     }

     fn token_to_id(&self, token: &str) -> Option<u32> {
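As a side note, the binding layer change above follows a plain delegation pattern: the wrapper method gains the same boolean parameter and forwards it unchanged to the wrapped tokenizer. A minimal sketch of that shape, with the pyo3 specifics omitted; InnerTokenizer and its decode body are hypothetical stand-ins, not the library's types.

// Hypothetical inner type standing in for tokenizers::Tokenizer.
struct InnerTokenizer;

impl InnerTokenizer {
    fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String, String> {
        // Real decoding omitted; this only shows that the flag reaches the core call.
        Ok(format!("{} ids, skip_special_tokens={}", ids.len(), skip_special_tokens))
    }
}

// Stand-in for the Python-facing wrapper: it exposes the new parameter
// and simply forwards it to the wrapped tokenizer.
struct Tokenizer {
    tokenizer: InnerTokenizer,
}

impl Tokenizer {
    fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String, String> {
        self.tokenizer.decode(ids, skip_special_tokens)
    }
}

fn main() {
    let t = Tokenizer { tokenizer: InnerTokenizer };
    println!("{:?}", t.decode(vec![1, 2, 3], true));
}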
@@ -53,7 +53,7 @@ fn shell(matches: &ArgMatches) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec()).unwrap()
+            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
@@ -342,15 +342,19 @@ impl Tokenizer {
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>) -> Result<String> {
+    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
             .into_iter()
             .map(|id| {
-                if let Some(token) = self.added_tokens_r.get(&id) {
+                let token = if let Some(token) = self.added_tokens_r.get(&id) {
                     Some(token.content.to_owned())
                 } else {
                     self.model.id_to_token(id)
-                }
+                };
+
+                token.filter(|token| {
+                    !skip_special_tokens || !self.special_tokens.contains_key(token)
+                })
             })
             .filter(|token| token.is_some())
             .map(|id| id.unwrap())
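For reference, a minimal standalone sketch of the filtering behavior introduced in the hunk above: each id is resolved to a token, and the token is kept only if the skip flag is off or the token is not registered as special. The vocab and special_tokens maps and the decode_ids helper are hypothetical stand-ins used for illustration, not the library's internals.

use std::collections::HashMap;

// Hypothetical stand-in for Tokenizer::decode, illustrating the Option::filter logic.
fn decode_ids(
    ids: Vec<u32>,
    vocab: &HashMap<u32, String>,
    special_tokens: &HashMap<String, u32>,
    skip_special_tokens: bool,
) -> String {
    ids.into_iter()
        .filter_map(|id| {
            // Look the id up, then drop it if we are skipping special tokens
            // and it is registered as a special token.
            vocab
                .get(&id)
                .cloned()
                .filter(|token| !skip_special_tokens || !special_tokens.contains_key(token))
        })
        .collect::<Vec<_>>()
        .join(" ")
}

fn main() {
    let vocab: HashMap<u32, String> = vec![
        (0, "[CLS]".to_string()),
        (1, "hello".to_string()),
        (2, "world".to_string()),
        (3, "[SEP]".to_string()),
    ]
    .into_iter()
    .collect();
    let special_tokens: HashMap<String, u32> =
        vec![("[CLS]".to_string(), 0), ("[SEP]".to_string(), 3)]
            .into_iter()
            .collect();

    let ids = vec![0, 1, 2, 3];
    // With the flag set, "[CLS]" and "[SEP]" are dropped from the output.
    println!("{}", decode_ids(ids.clone(), &vocab, &special_tokens, true)); // hello world
    println!("{}", decode_ids(ids, &vocab, &special_tokens, false)); // [CLS] hello world [SEP]
}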
@@ -364,10 +368,14 @@ impl Tokenizer {
     }

     /// Decode all sentences in parallel
-    pub fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Result<Vec<String>> {
+    pub fn decode_batch(
+        &self,
+        sentences: Vec<Vec<u32>>,
+        skip_special_tokens: bool,
+    ) -> Result<Vec<String>> {
         sentences
             .into_par_iter()
-            .map(|sentence| self.decode(sentence))
+            .map(|sentence| self.decode(sentence, skip_special_tokens))
             .collect()
     }

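And a sketch of the batch path, assuming the rayon crate for into_par_iter: each sentence is decoded in parallel and the skip flag is forwarded to every per-sentence call, with the Results collected into a single Result. The decode_one helper and its toy "id 0 is special" rule are hypothetical stand-ins for the library's decode.

use rayon::prelude::*;

// Hypothetical per-sentence decoder; in the library this would be Tokenizer::decode.
fn decode_one(ids: Vec<u32>, skip_special_tokens: bool) -> Result<String, String> {
    let kept: Vec<String> = ids
        .into_iter()
        .filter(|id| !skip_special_tokens || *id != 0) // pretend id 0 is a special token
        .map(|id| format!("tok{}", id))
        .collect();
    Ok(kept.join(" "))
}

// Mirrors the decode_batch shape: decode each sentence in parallel,
// forwarding the skip flag, and collect the per-sentence Results.
fn decode_batch(
    sentences: Vec<Vec<u32>>,
    skip_special_tokens: bool,
) -> Result<Vec<String>, String> {
    sentences
        .into_par_iter()
        .map(|sentence| decode_one(sentence, skip_special_tokens))
        .collect()
}

fn main() -> Result<(), String> {
    let decoded = decode_batch(vec![vec![0, 1, 2], vec![0, 3]], true)?;
    println!("{:?}", decoded); // ["tok1 tok2", "tok3"]
    Ok(())
}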