Improve the truncation of an Encoding

This commit is contained in:
Anthony MOI
2020-01-15 17:01:04 -05:00
committed by Pierric Cistac
parent 78e26905a7
commit 68f99bb822
2 changed files with 98 additions and 62 deletions

View File

@@ -17,7 +17,7 @@ pub struct Encoding {
offsets: Vec<(usize, usize)>,
special_tokens_mask: Vec<u32>,
attention_mask: Vec<u32>,
overflowing: Option<Box<Encoding>>,
overflowing: Vec<Encoding>,
}
impl Encoding {
#[allow(clippy::too_many_arguments)]
@@ -29,7 +29,7 @@ impl Encoding {
offsets: Vec<(usize, usize)>,
special_tokens_mask: Vec<u32>,
attention_mask: Vec<u32>,
overflowing: Option<Box<Encoding>>,
overflowing: Vec<Encoding>,
) -> Self {
Encoding {
normalized,
@@ -71,48 +71,89 @@ impl Encoding {
&self.attention_mask
}
pub fn get_overflowing(&self) -> Option<&Encoding> {
self.overflowing.as_ref().map(|b| &**b)
/// Borrow the overflowing encodings stored on this `Encoding`.
pub fn get_overflowing(&self) -> &Vec<Encoding> {
&self.overflowing
}
pub fn take_overflowing(&mut self) -> Option<Box<Encoding>> {
self.overflowing.take()
/// Take the overflowing encodings out of this `Encoding`.
///
/// The stored list is moved out to the caller and an empty `Vec` is left in
/// its place.
pub fn take_overflowing(&mut self) -> Vec<Encoding> {
    std::mem::take(&mut self.overflowing)
}
/// Truncate the current `Encoding` so that it keeps at most `max_len` tokens.
///
/// Nothing happens when the encoding already holds `max_len` tokens or fewer.
/// Otherwise everything past `max_len` is split off and stored in
/// `self.overflowing` as a list of parts, each one starting with the last
/// `stride` tokens of the part before it (the truncated `self` for the first).
///
/// # Panics
///
/// Panics if `stride >= max_len` and a truncation is actually needed.
pub fn truncate(&mut self, max_len: usize, stride: usize) {
// Already short enough: nothing to split off.
if max_len >= self.ids.len() {
return;
}
// NOTE(review): this listing is a rendered diff, so statements of the previous
// implementation (the `let mut o_*` bindings, the `if stride > 0` block and the
// `Some(Box::new(..))` assignment) appear next to the ones that replace them.
let mut o_ids = self.ids.split_off(max_len);
let mut o_type_ids = self.type_ids.split_off(max_len);
let mut o_tokens = self.tokens.split_off(max_len);
let mut o_offsets = self.offsets.split_off(max_len);
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
let mut o_attent = self.attention_mask.split_off(max_len);
// Get the main overflowing part
let o_ids = self.ids.split_off(max_len);
let o_type_ids = self.type_ids.split_off(max_len);
let o_tokens = self.tokens.split_off(max_len);
let o_offsets = self.offsets.split_off(max_len);
let o_spe_toks = self.special_tokens_mask.split_off(max_len);
let o_attent = self.attention_mask.split_off(max_len);
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
let trunc_normalized = self.normalized.split_off(max);
// Now we need to separate each overflowing part into as many Encoding as needed
assert!(stride < max_len);
// Every part adds `part_size` new tokens; its first `stride` tokens are
// copied from the end of the previous part.
let part_size = max_len - stride;
let mut overflowing = vec![];
let mut part_id = 0;
// The first part takes its leading `stride` tokens from the truncated `self`.
let mut prev_encoding: &Encoding = self;
if stride > 0 {
o_ids = prepend_stride(&self.ids, o_ids, stride);
o_type_ids = prepend_stride(&self.type_ids, o_type_ids, stride);
o_tokens = prepend_stride(&self.tokens, o_tokens, stride);
o_offsets = prepend_stride(&self.offsets, o_offsets, stride);
o_spe_toks = prepend_stride(&self.special_tokens_mask, o_spe_toks, stride);
o_attent = prepend_stride(&self.attention_mask, o_attent, stride);
loop {
// Done once all the split-off tokens have been handed out to a part.
if part_size * part_id >= o_ids.len() {
break;
}
let o = Encoding {
// Each part gets its own copy of `self.normalized`.
normalized: self.normalized.clone(),
ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
type_ids: get_current_part(
&prev_encoding.type_ids,
&o_type_ids,
part_size,
part_id,
stride,
),
tokens: get_current_part(
&prev_encoding.tokens,
&o_tokens,
part_size,
part_id,
stride,
),
offsets: get_current_part(
&prev_encoding.offsets,
&o_offsets,
part_size,
part_id,
stride,
),
special_tokens_mask: get_current_part(
&prev_encoding.special_tokens_mask,
&o_spe_toks,
part_size,
part_id,
stride,
),
attention_mask: get_current_part(
&prev_encoding.attention_mask,
&o_attent,
part_size,
part_id,
stride,
),
overflowing: vec![],
};
part_id += 1;
overflowing.push(o);
// The part just pushed provides the stride for the next one.
prev_encoding = &overflowing.last().unwrap();
}
self.overflowing = Some(Box::new(Encoding {
normalized: trunc_normalized,
ids: o_ids,
type_ids: o_type_ids,
tokens: o_tokens,
offsets: o_offsets,
special_tokens_mask: o_spe_toks,
attention_mask: o_attent,
overflowing: None,
}));
self.overflowing = overflowing;
}
pub fn merge_with(&mut self, pair: Encoding) {
@@ -180,32 +221,27 @@ impl Encoding {
}
}
/// Prepend the `stride` last elements of the `previous` `Vec` to the current `Vec`.
// A new Vec is instantiated though.
fn prepend_stride<T: Clone>(previous: &[T], current: Vec<T>, stride: usize) -> Vec<T> {
let prev = previous
.iter()
.rev()
.take(stride)
.cloned()
.rev()
.collect::<Vec<_>>();
[&prev[..], &current[..]].concat()
/// Build part number `idx` of an overflowing sequence.
///
/// The result is the last `stride` elements of `prev` followed by the
/// `idx`-th window of `size` elements taken from `current`. The final window
/// may hold fewer than `size` elements.
///
/// Out-of-range requests are clamped instead of panicking: a `stride` longer
/// than `prev` takes the whole of `prev`, and an `idx` pointing past the end
/// of `current` contributes no elements.
#[inline]
fn get_current_part<T: Clone>(
    prev: &[T],
    current: &[T],
    size: usize,
    idx: usize,
    stride: usize,
) -> Vec<T> {
    // Saturating arithmetic keeps large `idx`/`size` values from overflowing
    // before they are clamped to the length of `current`.
    let start = idx.saturating_mul(size).min(current.len());
    let end = start.saturating_add(size).min(current.len());
    let curr_slice = &current[start..end];

    // `saturating_sub` avoids the underflow when `stride > prev.len()`.
    let prev_slice = &prev[prev.len().saturating_sub(stride)..];

    [prev_slice, curr_slice].concat()
}
#[cfg(test)]
mod tests {
use super::*;
/// The last `stride` elements of the previous vector must end up in front of
/// the current one.
#[test]
fn test_prepend_stride() {
    let previous = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let current = vec![9, 10, 11, 12];
    let expected = vec![6, 7, 8, 9, 10, 11, 12];

    let merged = prepend_stride(&previous, current, 3);

    assert_eq!(merged, expected);
}
#[test]
fn merge_encodings() {
let mut a = Encoding {
@@ -216,7 +252,7 @@ mod tests {
offsets: vec![(0, 6)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
overflowing: None,
overflowing: vec![],
};
let b = Encoding {
normalized: NormalizedString::from("World!"),
@@ -226,7 +262,7 @@ mod tests {
offsets: vec![(0, 6)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
overflowing: None,
overflowing: vec![],
};
a.merge_with(b);
@@ -240,7 +276,7 @@ mod tests {
offsets: vec![(0, 6), (6, 12)],
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: None,
overflowing: vec![],
}
);
}
@@ -259,30 +295,30 @@ mod tests {
offsets: vec![(0, 5), (6, 11), (11, 12)],
special_tokens_mask: vec![0, 0, 0],
attention_mask: vec![1, 1, 1],
overflowing: None,
overflowing: vec![],
};
a.truncate(2, 0);
assert_eq!(
a,
Encoding {
normalized: NormalizedString::from("Hello World"),
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2],
type_ids: vec![0, 0],
tokens: vec![String::from("Hello"), String::from("World")],
offsets: vec![(0, 5), (6, 11)],
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: Some(Box::new(Encoding {
normalized: NormalizedString::from("!"),
overflowing: vec![Encoding {
normalized: NormalizedString::from("Hello World!"),
ids: vec![3],
type_ids: vec![0],
tokens: vec![String::from("!")],
offsets: vec![(11, 12)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
overflowing: None,
}))
overflowing: vec![],
}]
}
);
}

View File

@@ -285,7 +285,7 @@ impl Tokenizer {
vec![(0, sentence.len())],
vec![0],
vec![1],
None,
vec![],
));
}
@@ -321,7 +321,7 @@ impl Tokenizer {
offsets,
vec![0; length],
vec![1; length],
None,
vec![],
))
})
.collect::<Result<Vec<Encoding>>>()?;