Give error when initializing tokenizer with too high stride (#1306)

* Split `get_n_added_tokens` into separate method

* Modify `TokenizerImpl.with_truncation()` to raise an error if given bad parameters

* Return Python error if `tokenizer.with_truncation()` fails

* Add dummy variable assignment for `no_truncation()` case

* Unrelated fmt fix.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Connor Boyle
2023-07-28 00:16:44 -07:00
committed by GitHub
parent bb38f390a6
commit c2664ae13f
3 changed files with 38 additions and 17 deletions

View File

@@ -712,15 +712,16 @@ impl PyTokenizer {
} }
} }
self.tokenizer.with_truncation(Some(params)); if let Err(error_message) = self.tokenizer.with_truncation(Some(params)) {
return Err(PyError(error_message.to_string()).into_pyerr::<exceptions::PyValueError>());
}
Ok(()) Ok(())
} }
/// Disable truncation /// Disable truncation
#[pyo3(text_signature = "(self)")] #[pyo3(text_signature = "(self)")]
fn no_truncation(&mut self) { fn no_truncation(&mut self) {
self.tokenizer.with_truncation(None); let _ = self.tokenizer.with_truncation(None);
} }
/// Get the currently set truncation parameters /// Get the currently set truncation parameters

View File

@@ -497,6 +497,10 @@ impl DerefMut for Tokenizer {
} }
} }
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct TruncationParamError(String);
/// A `Tokenizer` is capable of encoding/decoding any text. /// A `Tokenizer` is capable of encoding/decoding any text.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct TokenizerImpl<M, N, PT, PP, D> { pub struct TokenizerImpl<M, N, PT, PP, D> {
@@ -595,9 +599,21 @@ where
} }
/// Set the truncation parameters /// Set the truncation parameters
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &mut Self { ///
/// Fails if `stride` is too high relative to `max_length` and `post_processor.added_tokens()`
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> Result<&mut Self> {
if let Some(trunc_params) = &trunc {
let n_added_tokens = self.get_n_added_tokens(false);
let effective_max_length = trunc_params.max_length - n_added_tokens;
if effective_max_length <= trunc_params.stride {
return Err(Box::new(TruncationParamError(format!(
"tokenizer stride set to {}, which is greater than or equal to its effective max length of {} (= {} original max length - {} added special tokens), ",
trunc_params.stride, effective_max_length, trunc_params.max_length, n_added_tokens
))));
}
}
self.truncation = trunc; self.truncation = trunc;
self Ok(self)
} }
/// Get the currently set truncation parameters /// Get the currently set truncation parameters
@@ -902,11 +918,7 @@ where
// 1. First we truncate if needed // 1. First we truncate if needed
let (encoding, pair_encoding) = { let (encoding, pair_encoding) = {
if let Some(trunc) = &self.truncation { if let Some(trunc) = &self.truncation {
let n_added_tokens = if let Some(processor) = &self.post_processor { let n_added_tokens = self.get_n_added_tokens(pair_encoding.is_some());
processor.added_tokens(pair_encoding.is_some())
} else {
0
};
if add_special_tokens && n_added_tokens > 0 { if add_special_tokens && n_added_tokens > 0 {
let params = TruncationParams { let params = TruncationParams {
@@ -950,6 +962,14 @@ where
Ok(final_encoding) Ok(final_encoding)
} }
fn get_n_added_tokens(&self, is_pair: bool) -> usize {
if let Some(processor) = &self.post_processor {
processor.added_tokens(is_pair)
} else {
0
}
}
} }
impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D> impl<M, N, PT, PP, D> TokenizerImpl<M, N, PT, PP, D>