Python - Simplify padding interface

This commit is contained in:
Anthony MOI
2019-12-26 14:34:13 -05:00
parent d1e59e09bf
commit 74cc6f6bde

View File

@ -139,29 +139,45 @@ impl Tokenizer {
self.tokenizer.with_truncation(None);
}
fn with_padding(
&mut self,
size: Option<usize>,
direction: &str,
pad_id: u32,
pad_type_id: u32,
pad_token: &str,
) -> PyResult<()> {
let strategy = if let Some(size) = size {
PaddingStrategy::Fixed(size)
#[args(kwargs = "**")]
fn with_padding(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut direction = PaddingDirection::Right;
let mut pad_id: u32 = 0;
let mut pad_type_id: u32 = 0;
let mut pad_token = String::from("[PAD]");
let mut max_length: Option<usize> = None;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"direction" => {
let value: &str = value.extract()?;
direction = match value {
"left" => Ok(PaddingDirection::Left),
"right" => Ok(PaddingDirection::Right),
other => Err(PyError(format!(
"Unknown `direction`: `{}`. Use \
one of `left` or `right`",
other
))
.into_pyerr()),
}?;
}
"pad_id" => pad_id = value.extract()?,
"pad_type_id" => pad_type_id = value.extract()?,
"pad_token" => pad_token = value.extract()?,
"max_length" => max_length = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
let strategy = if let Some(max_length) = max_length {
PaddingStrategy::Fixed(max_length)
} else {
PaddingStrategy::BatchLongest
};
let direction = match direction {
"left" => Ok(PaddingDirection::Left),
"right" => Ok(PaddingDirection::Right),
other => Err(PyError(format!(
"Unknown `direction`: `{}`. Use \
one of `left` or `right`",
other
))
.into_pyerr()),
}?;
self.tokenizer.with_padding(Some(PaddingParams {
strategy,