Fix LongestFirst truncation strategy

This commit is contained in:
Anthony MOI
2020-02-29 16:26:13 -05:00
parent 2f85ba21e6
commit f8f0702d98
4 changed files with 13 additions and 11 deletions

View File

@@ -4,6 +4,7 @@ Fixes:
- Some default tokens were missing from `BertWordPieceTokenizer` (cf [#160](https://github.com/huggingface/tokenizers/issues/160))
- There was a bug in ByteLevel PreTokenizer that caused offsets to be wrong if a char got split up
into multiple bytes. (cf [#156](https://github.com/huggingface/tokenizers/pull/156))
- The `longest_first` truncation strategy had a bug ([#174](https://github.com/huggingface/tokenizers/issues/174))
# v0.5.2
- Do not open all files directly while training ([#163](https://github.com/huggingface/tokenizers/issues/163))

View File

@@ -4,3 +4,4 @@
- Do not open all files directly while training ([#163](https://github.com/huggingface/tokenizers/issues/163))
- There was a bug in ByteLevel PreTokenizer that caused offsets to be wrong if a char got split up
into multiple bytes. (cf [#156](https://github.com/huggingface/tokenizers/pull/156))
- The `LongestFirst` truncation strategy had a bug ([#174](https://github.com/huggingface/tokenizers/issues/174))

View File

@@ -91,11 +91,13 @@ impl Encoding {
/// Truncate the current `Encoding`.
///
/// Panic if `stride >= max_len`.
/// Panic if `stride >= max_len` or `max_len == 0`.
pub fn truncate(&mut self, max_len: usize, stride: usize) {
if max_len >= self.ids.len() {
return;
}
// We only truncate if max_len > 0, it makes no sense otherwise
assert!(max_len > 0);
// Get the main overflowing part
let o_ids = self.ids.split_off(max_len);
@@ -105,7 +107,7 @@ impl Encoding {
let o_spe_toks = self.special_tokens_mask.split_off(max_len);
let o_attent = self.attention_mask.split_off(max_len);
// Now we need to separate each overflowing part into as many Encoding as needed
// Now we need to separate the overflowing part into as many Encoding as needed
assert!(stride < max_len);
let part_size = max_len - stride;
let mut overflowing = vec![];

View File

@@ -83,21 +83,19 @@ pub fn truncate_encodings(
match params.strategy {
TruncationStrategy::LongestFirst => {
let mut n_first = 0;
let mut n_second = 0;
let mut n_first = encoding.get_ids().len();
let mut n_second = pair_encoding.as_ref().map_or(0, |e| e.get_ids().len());
for _ in 0..to_remove {
if pair_encoding.is_none()
|| encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
{
n_first += 1;
if n_first > n_second {
n_first -= 1;
} else {
n_second += 1;
n_second -= 1;
}
}
encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
encoding.truncate(n_first, params.stride);
if let Some(encoding) = pair_encoding.as_mut() {
encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
encoding.truncate(n_second, params.stride);
}
}
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {