From 15085ef90582743b2ea1ca0ed1c864e5b946e9f4 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 4 May 2023 15:57:20 +0200
Subject: [PATCH] Fixing padding_left sequence_ids. (#1233)

---
 tokenizers/src/tokenizer/encoding.rs | 32 ++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs
index c6274c2f..40576efe 100644
--- a/tokenizers/src/tokenizer/encoding.rs
+++ b/tokenizers/src/tokenizer/encoding.rs
@@ -512,6 +512,11 @@ impl Encoding {
                     .map(|_| (0, 0))
                     .chain(self.offsets.drain(..))
                     .collect();
+                self.sequence_ranges
+                    .iter_mut()
+                    .for_each(|(_seq_id, range)| {
+                        *range = (range.start + pad_length)..(range.end + pad_length)
+                    });
             }
             PaddingDirection::Right => {
                 self.ids.extend((0..pad_length).map(|_| pad_id));
@@ -874,4 +879,31 @@ mod tests {
         assert_eq!(encoding.char_to_word(2, 1), Some(0));
         assert_eq!(encoding.char_to_word(9, 1), Some(2));
     }
+
+    #[test]
+    fn padding() {
+        let mut a = Encoding {
+            ids: vec![1],
+            type_ids: vec![0],
+            tokens: vec![String::from("Hello ")],
+            words: vec![Some(0)],
+            offsets: vec![(0, 6)],
+            special_tokens_mask: vec![0],
+            attention_mask: vec![1],
+            sequence_ranges: HashMap::from([(0, 0..1)]),
+            ..Default::default()
+        };
+        let target_length = 2;
+        let pad_id = 99;
+        let pad_type_id = 0;
+        let pad_token = "[PAD]";
+        a.pad(
+            target_length,
+            pad_id,
+            pad_type_id,
+            pad_token,
+            PaddingDirection::Left,
+        );
+        assert_eq!(a.sequence_ranges, HashMap::from([(0, 1..2)]));
+    }
 }
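For context on why the added lines are needed: left padding prepends `pad_length` tokens, so every pre-existing token moves `pad_length` positions to the right, and the stored `sequence_ranges` must be shifted by the same amount (right padding appends at the end and leaves existing indices untouched). The sketch below reproduces that bookkeeping as a standalone Rust program, outside the library; the `left_pad` function and its signature are illustrative assumptions, not part of the `tokenizers` API.

```rust
use std::collections::HashMap;
use std::ops::Range;

/// Illustrative left-pad (not the tokenizers API): prepends `pad_id` until
/// `target_len` is reached, then shifts every stored sequence range by the
/// number of tokens prepended, mirroring the patched `PaddingDirection::Left` branch.
fn left_pad(
    ids: &mut Vec<u32>,
    sequence_ranges: &mut HashMap<usize, Range<usize>>,
    target_len: usize,
    pad_id: u32,
) {
    let pad_length = target_len.saturating_sub(ids.len());
    if pad_length == 0 {
        return;
    }
    // Prepend the pad tokens: every existing token moves `pad_length` to the right.
    let mut padded = vec![pad_id; pad_length];
    padded.append(ids);
    *ids = padded;
    // Shift the sequence ranges by the same amount so they still point at the
    // original tokens rather than at the padding.
    for range in sequence_ranges.values_mut() {
        *range = (range.start + pad_length)..(range.end + pad_length);
    }
}

fn main() {
    let mut ids = vec![1];
    let mut sequence_ranges = HashMap::from([(0usize, 0..1)]);
    left_pad(&mut ids, &mut sequence_ranges, 2, 99);
    assert_eq!(ids, vec![99, 1]);
    // Matches the expectation in the added test: the single-token sequence now lives at 1..2.
    assert_eq!(sequence_ranges, HashMap::from([(0, 1..2)]));
    println!("ids = {ids:?}, sequence_ranges = {sequence_ranges:?}");
}
```

The final assertion corresponds to `assert_eq!(a.sequence_ranges, HashMap::from([(0, 1..2)]))` in the new `padding` test: without the shift, the range would still read `0..1` and point at the pad token instead of the original one.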