From 68f99bb8220cf73f4977cd67eff6f63f5e24eecf Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Wed, 15 Jan 2020 17:01:04 -0500 Subject: [PATCH] Improve the truncation of an Encoding --- tokenizers/src/tokenizer/encoding.rs | 156 ++++++++++++++++----------- tokenizers/src/tokenizer/mod.rs | 4 +- 2 files changed, 98 insertions(+), 62 deletions(-) diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index fdb430fe..07d6bae7 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -17,7 +17,7 @@ pub struct Encoding { offsets: Vec<(usize, usize)>, special_tokens_mask: Vec, attention_mask: Vec, - overflowing: Option>, + overflowing: Vec, } impl Encoding { #[allow(clippy::too_many_arguments)] @@ -29,7 +29,7 @@ impl Encoding { offsets: Vec<(usize, usize)>, special_tokens_mask: Vec, attention_mask: Vec, - overflowing: Option>, + overflowing: Vec, ) -> Self { Encoding { normalized, @@ -71,48 +71,89 @@ impl Encoding { &self.attention_mask } - pub fn get_overflowing(&self) -> Option<&Encoding> { - self.overflowing.as_ref().map(|b| &**b) + pub fn get_overflowing(&self) -> &Vec { + &self.overflowing } - pub fn take_overflowing(&mut self) -> Option> { - self.overflowing.take() + pub fn take_overflowing(&mut self) -> Vec { + std::mem::replace(&mut self.overflowing, vec![]) } + /// Truncate the current `Encoding`. + /// + /// Panic if `stride >= max_len`. pub fn truncate(&mut self, max_len: usize, stride: usize) { if max_len >= self.ids.len() { return; } - let mut o_ids = self.ids.split_off(max_len); - let mut o_type_ids = self.type_ids.split_off(max_len); - let mut o_tokens = self.tokens.split_off(max_len); - let mut o_offsets = self.offsets.split_off(max_len); - let mut o_spe_toks = self.special_tokens_mask.split_off(max_len); - let mut o_attent = self.attention_mask.split_off(max_len); + // Get the main overflowing part + let o_ids = self.ids.split_off(max_len); + let o_type_ids = self.type_ids.split_off(max_len); + let o_tokens = self.tokens.split_off(max_len); + let o_offsets = self.offsets.split_off(max_len); + let o_spe_toks = self.special_tokens_mask.split_off(max_len); + let o_attent = self.attention_mask.split_off(max_len); - let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0); - let trunc_normalized = self.normalized.split_off(max); + // Now we need to separate each overflowing part into as many Encoding as needed + assert!(stride < max_len); + let part_size = max_len - stride; + let mut overflowing = vec![]; + let mut part_id = 0; + let mut prev_encoding: &Encoding = self; - if stride > 0 { - o_ids = prepend_stride(&self.ids, o_ids, stride); - o_type_ids = prepend_stride(&self.type_ids, o_type_ids, stride); - o_tokens = prepend_stride(&self.tokens, o_tokens, stride); - o_offsets = prepend_stride(&self.offsets, o_offsets, stride); - o_spe_toks = prepend_stride(&self.special_tokens_mask, o_spe_toks, stride); - o_attent = prepend_stride(&self.attention_mask, o_attent, stride); + loop { + if part_size * part_id >= o_ids.len() { + break; + } + + let o = Encoding { + normalized: self.normalized.clone(), + ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride), + type_ids: get_current_part( + &prev_encoding.type_ids, + &o_type_ids, + part_size, + part_id, + stride, + ), + tokens: get_current_part( + &prev_encoding.tokens, + &o_tokens, + part_size, + part_id, + stride, + ), + offsets: get_current_part( + &prev_encoding.offsets, + &o_offsets, + part_size, + part_id, + stride, + ), + special_tokens_mask: get_current_part( + &prev_encoding.special_tokens_mask, + &o_spe_toks, + part_size, + part_id, + stride, + ), + attention_mask: get_current_part( + &prev_encoding.attention_mask, + &o_attent, + part_size, + part_id, + stride, + ), + overflowing: vec![], + }; + + part_id += 1; + overflowing.push(o); + prev_encoding = &overflowing.last().unwrap(); } - self.overflowing = Some(Box::new(Encoding { - normalized: trunc_normalized, - ids: o_ids, - type_ids: o_type_ids, - tokens: o_tokens, - offsets: o_offsets, - special_tokens_mask: o_spe_toks, - attention_mask: o_attent, - overflowing: None, - })); + self.overflowing = overflowing; } pub fn merge_with(&mut self, pair: Encoding) { @@ -180,32 +221,27 @@ impl Encoding { } } -/// Prepend the `stride` last elements of the `previous` `Vec` to the current `Vec`. -// A new Vec is instantiated though. -fn prepend_stride(previous: &[T], current: Vec, stride: usize) -> Vec { - let prev = previous - .iter() - .rev() - .take(stride) - .cloned() - .rev() - .collect::>(); - - [&prev[..], ¤t[..]].concat() +#[inline] +fn get_current_part( + prev: &[T], + current: &[T], + size: usize, + idx: usize, + stride: usize, +) -> Vec { + let curr_slice = if (idx + 1) * size > current.len() { + ¤t[idx * size..] + } else { + ¤t[idx * size..(idx + 1) * size] + }; + let prev_slice = &prev[prev.len() - stride..]; + [prev_slice, curr_slice].concat() } #[cfg(test)] mod tests { use super::*; - #[test] - fn test_prepend_stride() { - let prev = vec![1, 2, 3, 4, 5, 6, 7, 8]; - let curr = vec![9, 10, 11, 12]; - - assert_eq!(prepend_stride(&prev, curr, 3), vec![6, 7, 8, 9, 10, 11, 12]); - } - #[test] fn merge_encodings() { let mut a = Encoding { @@ -216,7 +252,7 @@ mod tests { offsets: vec![(0, 6)], special_tokens_mask: vec![0], attention_mask: vec![1], - overflowing: None, + overflowing: vec![], }; let b = Encoding { normalized: NormalizedString::from("World!"), @@ -226,7 +262,7 @@ mod tests { offsets: vec![(0, 6)], special_tokens_mask: vec![0], attention_mask: vec![1], - overflowing: None, + overflowing: vec![], }; a.merge_with(b); @@ -240,7 +276,7 @@ mod tests { offsets: vec![(0, 6), (6, 12)], special_tokens_mask: vec![0, 0], attention_mask: vec![1, 1], - overflowing: None, + overflowing: vec![], } ); } @@ -259,30 +295,30 @@ mod tests { offsets: vec![(0, 5), (6, 11), (11, 12)], special_tokens_mask: vec![0, 0, 0], attention_mask: vec![1, 1, 1], - overflowing: None, + overflowing: vec![], }; a.truncate(2, 0); assert_eq!( a, Encoding { - normalized: NormalizedString::from("Hello World"), + normalized: NormalizedString::from("Hello World!"), ids: vec![1, 2], type_ids: vec![0, 0], tokens: vec![String::from("Hello"), String::from("World")], offsets: vec![(0, 5), (6, 11)], special_tokens_mask: vec![0, 0], attention_mask: vec![1, 1], - overflowing: Some(Box::new(Encoding { - normalized: NormalizedString::from("!"), + overflowing: vec![Encoding { + normalized: NormalizedString::from("Hello World!"), ids: vec![3], type_ids: vec![0], tokens: vec![String::from("!")], offsets: vec![(11, 12)], special_tokens_mask: vec![0], attention_mask: vec![1], - overflowing: None, - })) + overflowing: vec![], + }] } ); } diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index cda330ad..6ed8bb76 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -285,7 +285,7 @@ impl Tokenizer { vec![(0, sentence.len())], vec![0], vec![1], - None, + vec![], )); } @@ -321,7 +321,7 @@ impl Tokenizer { offsets, vec![0; length], vec![1; length], - None, + vec![], )) }) .collect::>>()?;