Adding truncation_side within TruncationParams. (#860)

* Add truncation to enable_truncation

* Fix typo

* Adding truncation_side within `TruncationParams`.

* Node serialization of this direction param.

* Update the test.

* Fixing warnings/lint.

* Adding stuff (can't local debug :( )

* Slow loop... ;(

* Stub.py.

Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
This commit is contained in:
Nicolas Patry
2021-12-28 12:37:06 +01:00
committed by GitHub
parent c4c9de23a5
commit 152880ab3e
16 changed files with 478 additions and 26 deletions

View File

@ -10,7 +10,7 @@ use pyo3::PyObjectProtocol;
use tk::models::bpe::BPE;
use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
TruncationParams, TruncationStrategy,
TruncationDirection, TruncationParams, TruncationStrategy,
};
use tk::utils::iter::ResultShunt;
use tokenizers as tk;
@ -660,8 +660,11 @@ impl PyTokenizer {
/// strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
/// The strategy used to truncation. Can be one of ``longest_first``, ``only_first`` or
/// ``only_second``.
///
/// direction (:obj:`str`, defaults to :obj:`right`):
/// Truncate direction
#[args(kwargs = "**")]
#[text_signature = "(self, max_length, stride=0, strategy='longest_first')"]
#[text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"]
fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut params = TruncationParams {
max_length,
@ -687,6 +690,19 @@ impl PyTokenizer {
.into_pyerr::<exceptions::PyValueError>()),
}?
}
"direction" => {
let value: &str = value.extract()?;
params.direction = match value {
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
_ => Err(PyError(format!(
"Unknown `direction`: `{}`. Use \
one of `left` or `right`.",
value
))
.into_pyerr::<exceptions::PyValueError>()),
}?
}
_ => println!("Ignored unknown kwarg option {}", key),
}
}
@ -718,6 +734,7 @@ impl PyTokenizer {
dict.set_item("max_length", params.max_length)?;
dict.set_item("stride", params.stride)?;
dict.set_item("strategy", params.strategy.as_ref())?;
dict.set_item("direction", params.direction.as_ref())?;
Ok(Some(dict))
})