Python - Expose pad and truncate on Encoding

This commit is contained in:
Anthony MOI
2019-12-30 12:56:07 -05:00
parent 8ddb2de64e
commit 4677a09626

View File

@ -1,7 +1,9 @@
extern crate tokenizers as tk; extern crate tokenizers as tk;
use crate::error::PyError;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::*; use pyo3::types::*;
use tk::tokenizer::PaddingDirection;
#[pyclass(dict)] #[pyclass(dict)]
#[repr(transparent)] #[repr(transparent)]
@ -89,4 +91,58 @@ impl Encoding {
fn get_overflowing(&self) -> Option<Encoding> { fn get_overflowing(&self) -> Option<Encoding> {
self.encoding.get_overflowing().cloned().map(Encoding::new) self.encoding.get_overflowing().cloned().map(Encoding::new)
} }
#[args(kwargs = "**")]
fn pad(&mut self, length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut pad_id = 0;
let mut pad_type_id = 0;
let mut pad_token = "[PAD]";
let mut direction = PaddingDirection::Right;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"direction" => {
let value: &str = value.extract()?;
direction = match value {
"left" => Ok(PaddingDirection::Left),
"right" => Ok(PaddingDirection::Right),
other => Err(PyError(format!(
"Unknown `direction`: `{}`. Use \
one of `left` or `right`",
other
))
.into_pyerr()),
}?;
}
"pad_id" => pad_id = value.extract()?,
"pad_type_id" => pad_type_id = value.extract()?,
"pad_token" => pad_token = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
Ok(self
.encoding
.pad(length, pad_id, pad_type_id, pad_token, &direction))
}
#[args(kwargs = "**")]
fn truncate(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut stride = 0;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"stride" => stride = value.extract()?,
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
Ok(self.encoding.truncate(max_length, stride))
}
} }