From 4677a09626688c121f08af8b5e67c4fde397d5d4 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Mon, 30 Dec 2019 12:56:07 -0500 Subject: [PATCH] Python - Expose pad and truncate on Encoding --- bindings/python/src/encoding.rs | 56 +++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/bindings/python/src/encoding.rs b/bindings/python/src/encoding.rs index f5d5e9d0..b979327e 100644 --- a/bindings/python/src/encoding.rs +++ b/bindings/python/src/encoding.rs @@ -1,7 +1,9 @@ extern crate tokenizers as tk; +use crate::error::PyError; use pyo3::prelude::*; use pyo3::types::*; +use tk::tokenizer::PaddingDirection; #[pyclass(dict)] #[repr(transparent)] @@ -89,4 +91,58 @@ impl Encoding { fn get_overflowing(&self) -> Option { self.encoding.get_overflowing().cloned().map(Encoding::new) } + + #[args(kwargs = "**")] + fn pad(&mut self, length: usize, kwargs: Option<&PyDict>) -> PyResult<()> { + let mut pad_id = 0; + let mut pad_type_id = 0; + let mut pad_token = "[PAD]"; + let mut direction = PaddingDirection::Right; + + if let Some(kwargs) = kwargs { + for (key, value) in kwargs { + let key: &str = key.extract()?; + match key { + "direction" => { + let value: &str = value.extract()?; + direction = match value { + "left" => Ok(PaddingDirection::Left), + "right" => Ok(PaddingDirection::Right), + other => Err(PyError(format!( + "Unknown `direction`: `{}`. Use \ + one of `left` or `right`", + other + )) + .into_pyerr()), + }?; + } + "pad_id" => pad_id = value.extract()?, + "pad_type_id" => pad_type_id = value.extract()?, + "pad_token" => pad_token = value.extract()?, + _ => println!("Ignored unknown kwarg option {}", key), + } + } + } + + Ok(self + .encoding + .pad(length, pad_id, pad_type_id, pad_token, &direction)) + } + + #[args(kwargs = "**")] + fn truncate(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> { + let mut stride = 0; + + if let Some(kwargs) = kwargs { + for (key, value) in kwargs { + let key: &str = key.extract()?; + match key { + "stride" => stride = value.extract()?, + _ => println!("Ignored unknown kwarg option {}", key), + } + } + } + + Ok(self.encoding.truncate(max_length, stride)) + } }