mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 14:59:20 +00:00
Give error when initializing tokenizer with too high stride (#1306)
* Split `get_n_added_tokens` into separate method
* Modify `TokenizerImpl.with_truncation()` to raise an error if given bad parameters
* Return Python error if `tokenizer.with_truncation()` fails
* Add dummy variable assignment for `no_truncation()` case
* Unrelated fmt fix.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -712,15 +712,16 @@ impl PyTokenizer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.tokenizer.with_truncation(Some(params));
|
if let Err(error_message) = self.tokenizer.with_truncation(Some(params)) {
|
||||||
|
return Err(PyError(error_message.to_string()).into_pyerr::<exceptions::PyValueError>());
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Disable truncation
|
/// Disable truncation
|
||||||
#[pyo3(text_signature = "(self)")]
|
#[pyo3(text_signature = "(self)")]
|
||||||
fn no_truncation(&mut self) {
|
fn no_truncation(&mut self) {
|
||||||
self.tokenizer.with_truncation(None);
|
let _ = self.tokenizer.with_truncation(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the currently set truncation parameters
|
/// Get the currently set truncation parameters
|
||||||
|
@ -497,6 +497,10 @@ impl DerefMut for Tokenizer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(thiserror::Error, Debug)]
|
||||||
|
#[error("{0}")]
|
||||||
|
pub struct TruncationParamError(String);
|
||||||
|
|
||||||
/// A `Tokenizer` is capable of encoding/decoding any text.
|
/// A `Tokenizer` is capable of encoding/decoding any text.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct TokenizerImpl<M, N, PT, PP, D> {
|
pub struct TokenizerImpl<M, N, PT, PP, D> {
|
||||||
@ -595,9 +599,21 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Set the truncation parameters
|
/// Set the truncation parameters
|
||||||
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &mut Self {
|
///
|
||||||
|
/// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
|
||||||
|
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
|
||||||
|
if let Some(trunc_params) = &trunc {
|
||||||
|
let n_added_tokens = self.get_n_added_tokens(false);
|
||||||
|
let effective_max_length = trunc_params.max_length - n_added_tokens;
|
||||||
|
if effective_max_length <= trunc_params.stride {
|
||||||
|
return Err(Box::new(TruncationParamError(format!(
|
||||||
|
"tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
|
||||||
|
trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
|
||||||
|
))));
|
||||||
|
}
|
||||||
|
}
|
||||||
self.truncation = trunc;
|
self.truncation = trunc;
|
||||||
self
|
Ok(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the currently set truncation parameters
|
/// Get the currently set truncation parameters
|
||||||
@ -902,11 +918,7 @@ where
|
|||||||
// 1. First we truncate if needed
|
// 1. First we truncate if needed
|
||||||
let (encoding, pair_encoding) = {
|
let (encoding, pair_encoding) = {
|
||||||
if let Some(trunc) = &self.truncation {
|
if let Some(trunc) = &self.truncation {
|
||||||
let n_added_tokens = if let Some(processor) = &self.post_processor {
|
let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());
|
||||||
processor.added_tokens(pair_encoding.is_some())
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
if add_special_tokens && n_added_tokens > 0 {
|
if add_special_tokens && n_added_tokens > 0 {
|
||||||
let params = TruncationParams {
|
let params = TruncationParams {
|
||||||
@ -950,6 +962,14 @@ where
|
|||||||
|
|
||||||
Ok(final_encoding)
|
Ok(final_encoding)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_n_added_tokens(&self, is_pair: bool) -> usize {
|
||||||
|
if let Some(processor) = &self.post_processor {
|
||||||
|
processor.added_tokens(is_pair)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
|
impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
|
||||||
|
Reference in New Issue
Block a user