Fixing padding_left sequence_ids. (#1233)

This commit is contained in:
Nicolas Patry
2023-05-04 15:57:20 +02:00
committed by GitHub
parent ef5f50605d
commit 15085ef905

View File

@ -512,6 +512,11 @@ impl Encoding {
.map(|_| (0, 0))
.chain(self.offsets.drain(..))
.collect();
self.sequence_ranges
.iter_mut()
.for_each(|(_seq_id, range)| {
*range = (range.start + pad_length)..(range.end + pad_length)
});
}
PaddingDirection::Right => {
self.ids.extend((0..pad_length).map(|_| pad_id));
@ -874,4 +879,31 @@ mod tests {
assert_eq!(encoding.char_to_word(2, 1), Some(0));
assert_eq!(encoding.char_to_word(9, 1), Some(2));
}
#[test]
fn padding() {
let mut a = Encoding {
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello ")],
words: vec![Some(0)],
offsets: vec![(0, 6)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
sequence_ranges: HashMap::from([(0, 0..1)]),
..Default::default()
};
let target_length = 2;
let pad_id = 99;
let pad_type_id = 0;
let pad_token = "[PAD]";
a.pad(
target_length,
pad_id,
pad_type_id,
pad_token,
PaddingDirection::Left,
);
assert_eq!(a.sequence_ranges, HashMap::from([(0, 1..2)]));
}
}