mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-07 21:28:19 +00:00
Fix LongestFirst truncation strategy
This commit is contained in:
@@ -4,6 +4,7 @@ Fixes:
|
||||
- Some default tokens were missing from `BertWordPieceTokenizer` (cf [#160](https://github.com/huggingface/tokenizers/issues/160))
|
||||
- There was a bug in ByteLevel PreTokenizer that caused offsets to be wrong if a char got split up
|
||||
into multiple bytes. (cf [#156](https://github.com/huggingface/tokenizers/pull/156))
|
||||
- The `longest_first` truncation strategy had a bug ([#174](https://github.com/huggingface/tokenizers/issues/174))
|
||||
|
||||
# v0.5.2
|
||||
- Do not open all files directly while training ([#163](https://github.com/huggingface/tokenizers/issues/163))
|
||||
|
||||
@@ -4,3 +4,4 @@
|
||||
- Do not open all files directly while training ([#163](https://github.com/huggingface/tokenizers/issues/163))
|
||||
- There was a bug in ByteLevel PreTokenizer that caused offsets to be wrong if a char got split up
|
||||
into multiple bytes. (cf [#156](https://github.com/huggingface/tokenizers/pull/156))
|
||||
- The `LongestFirst` truncation strategy had a bug ([#174](https://github.com/huggingface/tokenizers/issues/174))
|
||||
|
||||
@@ -91,11 +91,13 @@ impl Encoding {
|
||||
|
||||
/// Truncate the current `Encoding`.
|
||||
///
|
||||
/// Panic if `stride >= max_len`.
|
||||
/// Panics if `stride >= max_len` or `max_len == 0` (checked only when the encoding is longer than `max_len`).
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize) {
|
||||
if max_len >= self.ids.len() {
|
||||
return;
|
||||
}
|
||||
// We only truncate if max_len > 0, it makes no sense otherwise
|
||||
assert!(max_len > 0);
|
||||
|
||||
// Get the main overflowing part
|
||||
let o_ids = self.ids.split_off(max_len);
|
||||
@@ -105,7 +107,7 @@ impl Encoding {
|
||||
let o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||
let o_attent = self.attention_mask.split_off(max_len);
|
||||
|
||||
// Now we need to separate each overflowing part into as many Encoding as needed
|
||||
// Now we need to separate the overflowing part into as many Encoding as needed
|
||||
assert!(stride < max_len);
|
||||
let part_size = max_len - stride;
|
||||
let mut overflowing = vec![];
|
||||
|
||||
@@ -83,21 +83,19 @@ pub fn truncate_encodings(
|
||||
|
||||
match params.strategy {
|
||||
TruncationStrategy::LongestFirst => {
|
||||
let mut n_first = 0;
|
||||
let mut n_second = 0;
|
||||
let mut n_first = encoding.get_ids().len();
|
||||
let mut n_second = pair_encoding.as_ref().map_or(0, |e| e.get_ids().len());
|
||||
for _ in 0..to_remove {
|
||||
if pair_encoding.is_none()
|
||||
|| encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
|
||||
{
|
||||
n_first += 1;
|
||||
if n_first > n_second {
|
||||
n_first -= 1;
|
||||
} else {
|
||||
n_second += 1;
|
||||
n_second -= 1;
|
||||
}
|
||||
}
|
||||
|
||||
encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
|
||||
encoding.truncate(n_first, params.stride);
|
||||
if let Some(encoding) = pair_encoding.as_mut() {
|
||||
encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
|
||||
encoding.truncate(n_second, params.stride);
|
||||
}
|
||||
}
|
||||
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
|
||||
|
||||
Reference in New Issue
Block a user