diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md
index 577925bc..344e8df8 100644
--- a/bindings/python/CHANGELOG.md
+++ b/bindings/python/CHANGELOG.md
@@ -4,6 +4,7 @@ Fixes:
 - Some default tokens were missing from `BertWordPieceTokenizer` (cf [#160](https://github.com/huggingface/tokenizers/issues/160))
 - There was a bug in ByteLevel PreTokenizer that caused offsets to be wrong if a char got split up in multiple bytes. (cf [#156](https://github.com/huggingface/tokenizers/pull/156))
+- The `longest_first` truncation strategy had a bug ([#174](https://github.com/huggingface/tokenizers/issues/174))
 
 # v0.5.2
 
 - Do not open all files directly while training ([#163](https://github.com/huggingface/tokenizers/issues/163))
diff --git a/tokenizers/CHANGELOG.md b/tokenizers/CHANGELOG.md
index e1fea14b..b0f29a84 100644
--- a/tokenizers/CHANGELOG.md
+++ b/tokenizers/CHANGELOG.md
@@ -4,3 +4,4 @@
 
 - Do not open all files directly while training ([#163](https://github.com/huggingface/tokenizers/issues/163))
 - There was a bug in ByteLevel PreTokenizer that caused offsets to be wrong if a char got split up in multiple bytes. (cf [#156](https://github.com/huggingface/tokenizers/pull/156))
+- The `LongestFirst` truncation strategy had a bug ([#174](https://github.com/huggingface/tokenizers/issues/174))
diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs
index 8ee58fee..c9a29e0c 100644
--- a/tokenizers/src/tokenizer/encoding.rs
+++ b/tokenizers/src/tokenizer/encoding.rs
@@ -91,11 +91,13 @@ impl Encoding {
 
     /// Truncate the current `Encoding`.
     ///
-    /// Panic if `stride >= max_len`.
+    /// Panic if `stride >= max_len` or `max_len == 0`.
     pub fn truncate(&mut self, max_len: usize, stride: usize) {
         if max_len >= self.ids.len() {
            return;
         }
+        // We only truncate if max_len > 0, it makes no sense otherwise
+        assert!(max_len > 0);
 
         // Get the main overflowing part
         let o_ids = self.ids.split_off(max_len);
@@ -105,7 +107,7 @@ impl Encoding {
         let o_spe_toks = self.special_tokens_mask.split_off(max_len);
         let o_attent = self.attention_mask.split_off(max_len);
 
-        // Now we need to separate each overflowing part into as many Encoding as needed
+        // Now we need to separate the overflowing part into as many Encoding as needed
         assert!(stride < max_len);
         let part_size = max_len - stride;
         let mut overflowing = vec![];
diff --git a/tokenizers/src/utils.rs b/tokenizers/src/utils.rs
index 2a8c187a..efc3ca5e 100644
--- a/tokenizers/src/utils.rs
+++ b/tokenizers/src/utils.rs
@@ -83,21 +83,19 @@ pub fn truncate_encodings(
 
     match params.strategy {
         TruncationStrategy::LongestFirst => {
-            let mut n_first = 0;
-            let mut n_second = 0;
+            let mut n_first = encoding.get_ids().len();
+            let mut n_second = pair_encoding.as_ref().map_or(0, |e| e.get_ids().len());
             for _ in 0..to_remove {
-                if pair_encoding.is_none()
-                    || encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
-                {
-                    n_first += 1;
+                if n_first > n_second {
+                    n_first -= 1;
                 } else {
-                    n_second += 1;
+                    n_second -= 1;
                 }
             }
 
-            encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
+            encoding.truncate(n_first, params.stride);
             if let Some(encoding) = pair_encoding.as_mut() {
-                encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
+                encoding.truncate(n_second, params.stride);
             }
         }
         TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {