Python - Improve decode/decode_batch API

This commit is contained in:
Anthony MOI
2020-01-06 16:39:36 -05:00
parent 1a083a6e6f
commit b7d0acc562

View File

@@ -244,15 +244,33 @@ impl Tokenizer {
.into() .into()
} }
fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> { #[args(kwargs = "**")]
fn decode(&self, ids: Vec<u32>, kwargs: Option<&PyDict>) -> PyResult<String> {
let mut skip_special_tokens = true;
if let Some(kwargs) = kwargs {
if let Some(skip) = kwargs.get_item("skip_special_tokens") {
skip_special_tokens = skip.extract()?;
}
}
ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into() ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
} }
#[args(kwargs = "**")]
fn decode_batch( fn decode_batch(
&self, &self,
sentences: Vec<Vec<u32>>, sentences: Vec<Vec<u32>>,
skip_special_tokens: bool, kwargs: Option<&PyDict>,
) -> PyResult<Vec<String>> { ) -> PyResult<Vec<String>> {
let mut skip_special_tokens = true;
if let Some(kwargs) = kwargs {
if let Some(skip) = kwargs.get_item("skip_special_tokens") {
skip_special_tokens = skip.extract()?;
}
}
ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into() ToPyResult(self.tokenizer.decode_batch(sentences, skip_special_tokens)).into()
} }