mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 14:59:20 +00:00
Give error when initializing tokenizer with too high stride (#1306)
* Split `get_n_added_tokens` into separate method * Modify `TokenizerImpl.with_truncation()` to raise an error if given bad parameters * Return Python error if `tokenizer.with_truncation()` fails * Add dummy variable assignment for `no_truncation()` case * Unrelated fmt fix. --------- Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -712,15 +712,16 @@ impl PyTokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
self.tokenizer.with_truncation(Some(params));
|
||||
|
||||
if let Err(error_message) = self.tokenizer.with_truncation(Some(params)) {
|
||||
return Err(PyError(error_message.to_string()).into_pyerr::<exceptions::PyValueError>());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Disable truncation
|
||||
#[pyo3(text_signature = "(self)")]
|
||||
fn no_truncation(&mut self) {
|
||||
self.tokenizer.with_truncation(None);
|
||||
let _ = self.tokenizer.with_truncation(None);
|
||||
}
|
||||
|
||||
/// Get the currently set truncation parameters
|
||||
|
@ -497,6 +497,10 @@ impl DerefMut for Tokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[error("{0}")]
|
||||
pub struct TruncationParamError(String);
|
||||
|
||||
/// A `Tokenizer` is capable of encoding/decoding any text.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TokenizerImpl<M, N, PT, PP, D> {
|
||||
@ -595,9 +599,21 @@ where
|
||||
}
|
||||
|
||||
/// Set the truncation parameters
|
||||
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &mut Self {
|
||||
///
|
||||
/// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
|
||||
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
|
||||
if let Some(trunc_params) = &trunc {
|
||||
let n_added_tokens = self.get_n_added_tokens(false);
|
||||
let effective_max_length = trunc_params.max_length - n_added_tokens;
|
||||
if effective_max_length <= trunc_params.stride {
|
||||
return Err(Box::new(TruncationParamError(format!(
|
||||
"tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
|
||||
trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
|
||||
))));
|
||||
}
|
||||
}
|
||||
self.truncation = trunc;
|
||||
self
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Get the currently set truncation parameters
|
||||
@ -902,11 +918,7 @@ where
|
||||
// 1. First we truncate if needed
|
||||
let (encoding, pair_encoding) = {
|
||||
if let Some(trunc) = &self.truncation {
|
||||
let n_added_tokens = if let Some(processor) = &self.post_processor {
|
||||
processor.added_tokens(pair_encoding.is_some())
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());
|
||||
|
||||
if add_special_tokens && n_added_tokens > 0 {
|
||||
let params = TruncationParams {
|
||||
@ -950,6 +962,14 @@ where
|
||||
|
||||
Ok(final_encoding)
|
||||
}
|
||||
|
||||
fn get_n_added_tokens(&self, is_pair: bool) -> usize {
|
||||
if let Some(processor) = &self.post_processor {
|
||||
processor.added_tokens(is_pair)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>
|
||||
|
Reference in New Issue
Block a user