mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-12 21:38:35 +00:00
Improve the truncation of an Encoding
This commit is contained in:
committed by
Pierric Cistac
parent
78e26905a7
commit
68f99bb822
@@ -17,7 +17,7 @@ pub struct Encoding {
|
||||
offsets: Vec<(usize, usize)>,
|
||||
special_tokens_mask: Vec<u32>,
|
||||
attention_mask: Vec<u32>,
|
||||
overflowing: Option<Box<Encoding>>,
|
||||
overflowing: Vec<Encoding>,
|
||||
}
|
||||
impl Encoding {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -29,7 +29,7 @@ impl Encoding {
|
||||
offsets: Vec<(usize, usize)>,
|
||||
special_tokens_mask: Vec<u32>,
|
||||
attention_mask: Vec<u32>,
|
||||
overflowing: Option<Box<Encoding>>,
|
||||
overflowing: Vec<Encoding>,
|
||||
) -> Self {
|
||||
Encoding {
|
||||
normalized,
|
||||
@@ -71,48 +71,89 @@ impl Encoding {
|
||||
&self.attention_mask
|
||||
}
|
||||
|
||||
pub fn get_overflowing(&self) -> Option<&Encoding> {
|
||||
self.overflowing.as_ref().map(|b| &**b)
|
||||
pub fn get_overflowing(&self) -> &Vec<Encoding> {
|
||||
&self.overflowing
|
||||
}
|
||||
|
||||
pub fn take_overflowing(&mut self) -> Option<Box<Encoding>> {
|
||||
self.overflowing.take()
|
||||
pub fn take_overflowing(&mut self) -> Vec<Encoding> {
|
||||
std::mem::replace(&mut self.overflowing, vec![])
|
||||
}
|
||||
|
||||
/// Truncate the current `Encoding`.
|
||||
///
|
||||
/// Panic if `stride >= max_len`.
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize) {
|
||||
if max_len >= self.ids.len() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut o_ids = self.ids.split_off(max_len);
|
||||
let mut o_type_ids = self.type_ids.split_off(max_len);
|
||||
let mut o_tokens = self.tokens.split_off(max_len);
|
||||
let mut o_offsets = self.offsets.split_off(max_len);
|
||||
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||
let mut o_attent = self.attention_mask.split_off(max_len);
|
||||
// Get the main overflowing part
|
||||
let o_ids = self.ids.split_off(max_len);
|
||||
let o_type_ids = self.type_ids.split_off(max_len);
|
||||
let o_tokens = self.tokens.split_off(max_len);
|
||||
let o_offsets = self.offsets.split_off(max_len);
|
||||
let o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||
let o_attent = self.attention_mask.split_off(max_len);
|
||||
|
||||
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
|
||||
let trunc_normalized = self.normalized.split_off(max);
|
||||
// Now we need to separate each overflowing part into as many Encoding as needed
|
||||
assert!(stride < max_len);
|
||||
let part_size = max_len - stride;
|
||||
let mut overflowing = vec![];
|
||||
let mut part_id = 0;
|
||||
let mut prev_encoding: &Encoding = self;
|
||||
|
||||
if stride > 0 {
|
||||
o_ids = prepend_stride(&self.ids, o_ids, stride);
|
||||
o_type_ids = prepend_stride(&self.type_ids, o_type_ids, stride);
|
||||
o_tokens = prepend_stride(&self.tokens, o_tokens, stride);
|
||||
o_offsets = prepend_stride(&self.offsets, o_offsets, stride);
|
||||
o_spe_toks = prepend_stride(&self.special_tokens_mask, o_spe_toks, stride);
|
||||
o_attent = prepend_stride(&self.attention_mask, o_attent, stride);
|
||||
loop {
|
||||
if part_size * part_id >= o_ids.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let o = Encoding {
|
||||
normalized: self.normalized.clone(),
|
||||
ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
|
||||
type_ids: get_current_part(
|
||||
&prev_encoding.type_ids,
|
||||
&o_type_ids,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
tokens: get_current_part(
|
||||
&prev_encoding.tokens,
|
||||
&o_tokens,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
offsets: get_current_part(
|
||||
&prev_encoding.offsets,
|
||||
&o_offsets,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
special_tokens_mask: get_current_part(
|
||||
&prev_encoding.special_tokens_mask,
|
||||
&o_spe_toks,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
attention_mask: get_current_part(
|
||||
&prev_encoding.attention_mask,
|
||||
&o_attent,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
overflowing: vec![],
|
||||
};
|
||||
|
||||
part_id += 1;
|
||||
overflowing.push(o);
|
||||
prev_encoding = &overflowing.last().unwrap();
|
||||
}
|
||||
|
||||
self.overflowing = Some(Box::new(Encoding {
|
||||
normalized: trunc_normalized,
|
||||
ids: o_ids,
|
||||
type_ids: o_type_ids,
|
||||
tokens: o_tokens,
|
||||
offsets: o_offsets,
|
||||
special_tokens_mask: o_spe_toks,
|
||||
attention_mask: o_attent,
|
||||
overflowing: None,
|
||||
}));
|
||||
self.overflowing = overflowing;
|
||||
}
|
||||
|
||||
pub fn merge_with(&mut self, pair: Encoding) {
|
||||
@@ -180,32 +221,27 @@ impl Encoding {
|
||||
}
|
||||
}
|
||||
|
||||
/// Prepend the `stride` last elements of the `previous` `Vec` to the current `Vec`.
|
||||
// A new Vec is instantiated though.
|
||||
fn prepend_stride<T: Clone>(previous: &[T], current: Vec<T>, stride: usize) -> Vec<T> {
|
||||
let prev = previous
|
||||
.iter()
|
||||
.rev()
|
||||
.take(stride)
|
||||
.cloned()
|
||||
.rev()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
[&prev[..], &current[..]].concat()
|
||||
/// Build one overflowing part: the last `stride` elements of `prev` followed by
/// the `idx`-th chunk of at most `size` elements taken from `current`.
///
/// * `prev`    - the part that precedes this one; its tail provides the overlap.
/// * `current` - the full overflowing sequence being split into chunks.
/// * `size`    - number of elements of `current` consumed by each chunk.
/// * `idx`     - zero-based index of the chunk to extract.
/// * `stride`  - number of trailing elements of `prev` to repeat at the front.
///
/// The last chunk is shorter than `size` when `current.len()` is not a multiple
/// of `size`. If `idx` points past the end of `current` the chunk is empty, and
/// if `prev` holds fewer than `stride` elements the whole of `prev` is reused.
#[inline]
fn get_current_part<T: Clone>(
    prev: &[T],
    current: &[T],
    size: usize,
    idx: usize,
    stride: usize,
) -> Vec<T> {
    // Clamp both ends of the chunk so a short (or out-of-range) chunk never
    // slices past the end of `current`.
    let start = idx.saturating_mul(size).min(current.len());
    let end = start.saturating_add(size).min(current.len());
    let curr_slice = &current[start..end];

    // `saturating_sub` avoids a usize underflow when `prev` is shorter than `stride`.
    let prev_slice = &prev[prev.len().saturating_sub(stride)..];

    [prev_slice, curr_slice].concat()
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_prepend_stride() {
|
||||
let prev = vec![1, 2, 3, 4, 5, 6, 7, 8];
|
||||
let curr = vec![9, 10, 11, 12];
|
||||
|
||||
assert_eq!(prepend_stride(&prev, curr, 3), vec![6, 7, 8, 9, 10, 11, 12]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_encodings() {
|
||||
let mut a = Encoding {
|
||||
@@ -216,7 +252,7 @@ mod tests {
|
||||
offsets: vec![(0, 6)],
|
||||
special_tokens_mask: vec![0],
|
||||
attention_mask: vec![1],
|
||||
overflowing: None,
|
||||
overflowing: vec![],
|
||||
};
|
||||
let b = Encoding {
|
||||
normalized: NormalizedString::from("World!"),
|
||||
@@ -226,7 +262,7 @@ mod tests {
|
||||
offsets: vec![(0, 6)],
|
||||
special_tokens_mask: vec![0],
|
||||
attention_mask: vec![1],
|
||||
overflowing: None,
|
||||
overflowing: vec![],
|
||||
};
|
||||
a.merge_with(b);
|
||||
|
||||
@@ -240,7 +276,7 @@ mod tests {
|
||||
offsets: vec![(0, 6), (6, 12)],
|
||||
special_tokens_mask: vec![0, 0],
|
||||
attention_mask: vec![1, 1],
|
||||
overflowing: None,
|
||||
overflowing: vec![],
|
||||
}
|
||||
);
|
||||
}
|
||||
@@ -259,30 +295,30 @@ mod tests {
|
||||
offsets: vec![(0, 5), (6, 11), (11, 12)],
|
||||
special_tokens_mask: vec![0, 0, 0],
|
||||
attention_mask: vec![1, 1, 1],
|
||||
overflowing: None,
|
||||
overflowing: vec![],
|
||||
};
|
||||
a.truncate(2, 0);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
Encoding {
|
||||
normalized: NormalizedString::from("Hello World"),
|
||||
normalized: NormalizedString::from("Hello World!"),
|
||||
ids: vec![1, 2],
|
||||
type_ids: vec![0, 0],
|
||||
tokens: vec![String::from("Hello"), String::from("World")],
|
||||
offsets: vec![(0, 5), (6, 11)],
|
||||
special_tokens_mask: vec![0, 0],
|
||||
attention_mask: vec![1, 1],
|
||||
overflowing: Some(Box::new(Encoding {
|
||||
normalized: NormalizedString::from("!"),
|
||||
overflowing: vec![Encoding {
|
||||
normalized: NormalizedString::from("Hello World!"),
|
||||
ids: vec![3],
|
||||
type_ids: vec![0],
|
||||
tokens: vec![String::from("!")],
|
||||
offsets: vec![(11, 12)],
|
||||
special_tokens_mask: vec![0],
|
||||
attention_mask: vec![1],
|
||||
overflowing: None,
|
||||
}))
|
||||
overflowing: vec![],
|
||||
}]
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
@@ -285,7 +285,7 @@ impl Tokenizer {
|
||||
vec![(0, sentence.len())],
|
||||
vec![0],
|
||||
vec![1],
|
||||
None,
|
||||
vec![],
|
||||
));
|
||||
}
|
||||
|
||||
@@ -321,7 +321,7 @@ impl Tokenizer {
|
||||
offsets,
|
||||
vec![0; length],
|
||||
vec![1; length],
|
||||
None,
|
||||
vec![],
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<Encoding>>>()?;
|
||||
|
||||
Reference in New Issue
Block a user