mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Python - Expose pad and truncate on Encoding
This commit is contained in:
@ -1,7 +1,9 @@
|
|||||||
extern crate tokenizers as tk;
|
extern crate tokenizers as tk;
|
||||||
|
|
||||||
|
use crate::error::PyError;
|
||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
use pyo3::types::*;
|
use pyo3::types::*;
|
||||||
|
use tk::tokenizer::PaddingDirection;
|
||||||
|
|
||||||
#[pyclass(dict)]
|
#[pyclass(dict)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
@ -89,4 +91,58 @@ impl Encoding {
|
|||||||
fn get_overflowing(&self) -> Option<Encoding> {
|
fn get_overflowing(&self) -> Option<Encoding> {
|
||||||
self.encoding.get_overflowing().cloned().map(Encoding::new)
|
self.encoding.get_overflowing().cloned().map(Encoding::new)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[args(kwargs = "**")]
|
||||||
|
fn pad(&mut self, length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
|
||||||
|
let mut pad_id = 0;
|
||||||
|
let mut pad_type_id = 0;
|
||||||
|
let mut pad_token = "[PAD]";
|
||||||
|
let mut direction = PaddingDirection::Right;
|
||||||
|
|
||||||
|
if let Some(kwargs) = kwargs {
|
||||||
|
for (key, value) in kwargs {
|
||||||
|
let key: &str = key.extract()?;
|
||||||
|
match key {
|
||||||
|
"direction" => {
|
||||||
|
let value: &str = value.extract()?;
|
||||||
|
direction = match value {
|
||||||
|
"left" => Ok(PaddingDirection::Left),
|
||||||
|
"right" => Ok(PaddingDirection::Right),
|
||||||
|
other => Err(PyError(format!(
|
||||||
|
"Unknown `direction`: `{}`. Use \
|
||||||
|
one of `left` or `right`",
|
||||||
|
other
|
||||||
|
))
|
||||||
|
.into_pyerr()),
|
||||||
|
}?;
|
||||||
|
}
|
||||||
|
"pad_id" => pad_id = value.extract()?,
|
||||||
|
"pad_type_id" => pad_type_id = value.extract()?,
|
||||||
|
"pad_token" => pad_token = value.extract()?,
|
||||||
|
_ => println!("Ignored unknown kwarg option {}", key),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self
|
||||||
|
.encoding
|
||||||
|
.pad(length, pad_id, pad_type_id, pad_token, &direction))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[args(kwargs = "**")]
|
||||||
|
fn truncate(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
|
||||||
|
let mut stride = 0;
|
||||||
|
|
||||||
|
if let Some(kwargs) = kwargs {
|
||||||
|
for (key, value) in kwargs {
|
||||||
|
let key: &str = key.extract()?;
|
||||||
|
match key {
|
||||||
|
"stride" => stride = value.extract()?,
|
||||||
|
_ => println!("Ignored unknown kwarg option {}", key),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.encoding.truncate(max_length, stride))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user