Python - Simplify padding interface

This commit is contained in:
Anthony MOI
2019-12-26 14:34:13 -05:00
parent d1e59e09bf
commit 74cc6f6bde

View File

@ -139,20 +139,21 @@ impl Tokenizer {
self.tokenizer.with_truncation(None); self.tokenizer.with_truncation(None);
} }
fn with_padding( #[args(kwargs = "**")]
&mut self, fn with_padding(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
size: Option<usize>, let mut direction = PaddingDirection::Right;
direction: &str, let mut pad_id: u32 = 0;
pad_id: u32, let mut pad_type_id: u32 = 0;
pad_type_id: u32, let mut pad_token = String::from("[PAD]");
pad_token: &str, let mut max_length: Option<usize> = None;
) -> PyResult<()> {
let strategy = if let Some(size) = size { if let Some(kwargs) = kwargs {
PaddingStrategy::Fixed(size) for (key, value) in kwargs {
} else { let key: &str = key.extract()?;
PaddingStrategy::BatchLongest match key {
}; "direction" => {
let direction = match direction { let value: &str = value.extract()?;
direction = match value {
"left" => Ok(PaddingDirection::Left), "left" => Ok(PaddingDirection::Left),
"right" => Ok(PaddingDirection::Right), "right" => Ok(PaddingDirection::Right),
other => Err(PyError(format!( other => Err(PyError(format!(
@ -162,6 +163,21 @@ impl Tokenizer {
)) ))
.into_pyerr()), .into_pyerr()),
}?; }?;
}
"pad_id" => pad_id = value.extract()?,
"pad_type_id" => pad_type_id = value.extract()?,
"pad_token" => pad_token = value.extract()?,
"max_length" => max_length = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
let strategy = if let Some(max_length) = max_length {
PaddingStrategy::Fixed(max_length)
} else {
PaddingStrategy::BatchLongest
};
self.tokenizer.with_padding(Some(PaddingParams { self.tokenizer.with_padding(Some(PaddingParams {
strategy, strategy,