mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-13 22:08:48 +00:00
Improve the truncation of an Encoding
This commit is contained in:
committed by
Pierric Cistac
parent
78e26905a7
commit
68f99bb822
@@ -17,7 +17,7 @@ pub struct Encoding {
|
|||||||
offsets: Vec<(usize, usize)>,
|
offsets: Vec<(usize, usize)>,
|
||||||
special_tokens_mask: Vec<u32>,
|
special_tokens_mask: Vec<u32>,
|
||||||
attention_mask: Vec<u32>,
|
attention_mask: Vec<u32>,
|
||||||
overflowing: Option<Box<Encoding>>,
|
overflowing: Vec<Encoding>,
|
||||||
}
|
}
|
||||||
impl Encoding {
|
impl Encoding {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
@@ -29,7 +29,7 @@ impl Encoding {
|
|||||||
offsets: Vec<(usize, usize)>,
|
offsets: Vec<(usize, usize)>,
|
||||||
special_tokens_mask: Vec<u32>,
|
special_tokens_mask: Vec<u32>,
|
||||||
attention_mask: Vec<u32>,
|
attention_mask: Vec<u32>,
|
||||||
overflowing: Option<Box<Encoding>>,
|
overflowing: Vec<Encoding>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Encoding {
|
Encoding {
|
||||||
normalized,
|
normalized,
|
||||||
@@ -71,48 +71,89 @@ impl Encoding {
|
|||||||
&self.attention_mask
|
&self.attention_mask
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_overflowing(&self) -> Option<&Encoding> {
|
pub fn get_overflowing(&self) -> &Vec<Encoding> {
|
||||||
self.overflowing.as_ref().map(|b| &**b)
|
&self.overflowing
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn take_overflowing(&mut self) -> Option<Box<Encoding>> {
|
pub fn take_overflowing(&mut self) -> Vec<Encoding> {
|
||||||
self.overflowing.take()
|
std::mem::replace(&mut self.overflowing, vec![])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Truncate the current `Encoding`.
|
||||||
|
///
|
||||||
|
/// Panic if `stride >= max_len`.
|
||||||
pub fn truncate(&mut self, max_len: usize, stride: usize) {
|
pub fn truncate(&mut self, max_len: usize, stride: usize) {
|
||||||
if max_len >= self.ids.len() {
|
if max_len >= self.ids.len() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut o_ids = self.ids.split_off(max_len);
|
// Get the main overflowing part
|
||||||
let mut o_type_ids = self.type_ids.split_off(max_len);
|
let o_ids = self.ids.split_off(max_len);
|
||||||
let mut o_tokens = self.tokens.split_off(max_len);
|
let o_type_ids = self.type_ids.split_off(max_len);
|
||||||
let mut o_offsets = self.offsets.split_off(max_len);
|
let o_tokens = self.tokens.split_off(max_len);
|
||||||
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
let o_offsets = self.offsets.split_off(max_len);
|
||||||
let mut o_attent = self.attention_mask.split_off(max_len);
|
let o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||||
|
let o_attent = self.attention_mask.split_off(max_len);
|
||||||
|
|
||||||
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
|
// Now we need to separate each overflowing part into as many Encoding as needed
|
||||||
let trunc_normalized = self.normalized.split_off(max);
|
assert!(stride < max_len);
|
||||||
|
let part_size = max_len - stride;
|
||||||
|
let mut overflowing = vec![];
|
||||||
|
let mut part_id = 0;
|
||||||
|
let mut prev_encoding: &Encoding = self;
|
||||||
|
|
||||||
if stride > 0 {
|
loop {
|
||||||
o_ids = prepend_stride(&self.ids, o_ids, stride);
|
if part_size * part_id >= o_ids.len() {
|
||||||
o_type_ids = prepend_stride(&self.type_ids, o_type_ids, stride);
|
break;
|
||||||
o_tokens = prepend_stride(&self.tokens, o_tokens, stride);
|
}
|
||||||
o_offsets = prepend_stride(&self.offsets, o_offsets, stride);
|
|
||||||
o_spe_toks = prepend_stride(&self.special_tokens_mask, o_spe_toks, stride);
|
let o = Encoding {
|
||||||
o_attent = prepend_stride(&self.attention_mask, o_attent, stride);
|
normalized: self.normalized.clone(),
|
||||||
|
ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
|
||||||
|
type_ids: get_current_part(
|
||||||
|
&prev_encoding.type_ids,
|
||||||
|
&o_type_ids,
|
||||||
|
part_size,
|
||||||
|
part_id,
|
||||||
|
stride,
|
||||||
|
),
|
||||||
|
tokens: get_current_part(
|
||||||
|
&prev_encoding.tokens,
|
||||||
|
&o_tokens,
|
||||||
|
part_size,
|
||||||
|
part_id,
|
||||||
|
stride,
|
||||||
|
),
|
||||||
|
offsets: get_current_part(
|
||||||
|
&prev_encoding.offsets,
|
||||||
|
&o_offsets,
|
||||||
|
part_size,
|
||||||
|
part_id,
|
||||||
|
stride,
|
||||||
|
),
|
||||||
|
special_tokens_mask: get_current_part(
|
||||||
|
&prev_encoding.special_tokens_mask,
|
||||||
|
&o_spe_toks,
|
||||||
|
part_size,
|
||||||
|
part_id,
|
||||||
|
stride,
|
||||||
|
),
|
||||||
|
attention_mask: get_current_part(
|
||||||
|
&prev_encoding.attention_mask,
|
||||||
|
&o_attent,
|
||||||
|
part_size,
|
||||||
|
part_id,
|
||||||
|
stride,
|
||||||
|
),
|
||||||
|
overflowing: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
part_id += 1;
|
||||||
|
overflowing.push(o);
|
||||||
|
prev_encoding = &overflowing.last().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
self.overflowing = Some(Box::new(Encoding {
|
self.overflowing = overflowing;
|
||||||
normalized: trunc_normalized,
|
|
||||||
ids: o_ids,
|
|
||||||
type_ids: o_type_ids,
|
|
||||||
tokens: o_tokens,
|
|
||||||
offsets: o_offsets,
|
|
||||||
special_tokens_mask: o_spe_toks,
|
|
||||||
attention_mask: o_attent,
|
|
||||||
overflowing: None,
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn merge_with(&mut self, pair: Encoding) {
|
pub fn merge_with(&mut self, pair: Encoding) {
|
||||||
@@ -180,32 +221,27 @@ impl Encoding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Prepend the `stride` last elements of the `previous` `Vec` to the current `Vec`.
|
#[inline]
|
||||||
// A new Vec is instantiated though.
|
fn get_current_part<T: Clone>(
|
||||||
fn prepend_stride<T: Clone>(previous: &[T], current: Vec<T>, stride: usize) -> Vec<T> {
|
prev: &[T],
|
||||||
let prev = previous
|
current: &[T],
|
||||||
.iter()
|
size: usize,
|
||||||
.rev()
|
idx: usize,
|
||||||
.take(stride)
|
stride: usize,
|
||||||
.cloned()
|
) -> Vec<T> {
|
||||||
.rev()
|
let curr_slice = if (idx + 1) * size > current.len() {
|
||||||
.collect::<Vec<_>>();
|
&current[idx * size..]
|
||||||
|
} else {
|
||||||
[&prev[..], &current[..]].concat()
|
&current[idx * size..(idx + 1) * size]
|
||||||
|
};
|
||||||
|
let prev_slice = &prev[prev.len() - stride..];
|
||||||
|
[prev_slice, curr_slice].concat()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_prepend_stride() {
|
|
||||||
let prev = vec![1, 2, 3, 4, 5, 6, 7, 8];
|
|
||||||
let curr = vec![9, 10, 11, 12];
|
|
||||||
|
|
||||||
assert_eq!(prepend_stride(&prev, curr, 3), vec![6, 7, 8, 9, 10, 11, 12]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn merge_encodings() {
|
fn merge_encodings() {
|
||||||
let mut a = Encoding {
|
let mut a = Encoding {
|
||||||
@@ -216,7 +252,7 @@ mod tests {
|
|||||||
offsets: vec![(0, 6)],
|
offsets: vec![(0, 6)],
|
||||||
special_tokens_mask: vec![0],
|
special_tokens_mask: vec![0],
|
||||||
attention_mask: vec![1],
|
attention_mask: vec![1],
|
||||||
overflowing: None,
|
overflowing: vec![],
|
||||||
};
|
};
|
||||||
let b = Encoding {
|
let b = Encoding {
|
||||||
normalized: NormalizedString::from("World!"),
|
normalized: NormalizedString::from("World!"),
|
||||||
@@ -226,7 +262,7 @@ mod tests {
|
|||||||
offsets: vec![(0, 6)],
|
offsets: vec![(0, 6)],
|
||||||
special_tokens_mask: vec![0],
|
special_tokens_mask: vec![0],
|
||||||
attention_mask: vec![1],
|
attention_mask: vec![1],
|
||||||
overflowing: None,
|
overflowing: vec![],
|
||||||
};
|
};
|
||||||
a.merge_with(b);
|
a.merge_with(b);
|
||||||
|
|
||||||
@@ -240,7 +276,7 @@ mod tests {
|
|||||||
offsets: vec![(0, 6), (6, 12)],
|
offsets: vec![(0, 6), (6, 12)],
|
||||||
special_tokens_mask: vec![0, 0],
|
special_tokens_mask: vec![0, 0],
|
||||||
attention_mask: vec![1, 1],
|
attention_mask: vec![1, 1],
|
||||||
overflowing: None,
|
overflowing: vec![],
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -259,30 +295,30 @@ mod tests {
|
|||||||
offsets: vec![(0, 5), (6, 11), (11, 12)],
|
offsets: vec![(0, 5), (6, 11), (11, 12)],
|
||||||
special_tokens_mask: vec![0, 0, 0],
|
special_tokens_mask: vec![0, 0, 0],
|
||||||
attention_mask: vec![1, 1, 1],
|
attention_mask: vec![1, 1, 1],
|
||||||
overflowing: None,
|
overflowing: vec![],
|
||||||
};
|
};
|
||||||
a.truncate(2, 0);
|
a.truncate(2, 0);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
a,
|
a,
|
||||||
Encoding {
|
Encoding {
|
||||||
normalized: NormalizedString::from("Hello World"),
|
normalized: NormalizedString::from("Hello World!"),
|
||||||
ids: vec![1, 2],
|
ids: vec![1, 2],
|
||||||
type_ids: vec![0, 0],
|
type_ids: vec![0, 0],
|
||||||
tokens: vec![String::from("Hello"), String::from("World")],
|
tokens: vec![String::from("Hello"), String::from("World")],
|
||||||
offsets: vec![(0, 5), (6, 11)],
|
offsets: vec![(0, 5), (6, 11)],
|
||||||
special_tokens_mask: vec![0, 0],
|
special_tokens_mask: vec![0, 0],
|
||||||
attention_mask: vec![1, 1],
|
attention_mask: vec![1, 1],
|
||||||
overflowing: Some(Box::new(Encoding {
|
overflowing: vec![Encoding {
|
||||||
normalized: NormalizedString::from("!"),
|
normalized: NormalizedString::from("Hello World!"),
|
||||||
ids: vec![3],
|
ids: vec![3],
|
||||||
type_ids: vec![0],
|
type_ids: vec![0],
|
||||||
tokens: vec![String::from("!")],
|
tokens: vec![String::from("!")],
|
||||||
offsets: vec![(11, 12)],
|
offsets: vec![(11, 12)],
|
||||||
special_tokens_mask: vec![0],
|
special_tokens_mask: vec![0],
|
||||||
attention_mask: vec![1],
|
attention_mask: vec![1],
|
||||||
overflowing: None,
|
overflowing: vec![],
|
||||||
}))
|
}]
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ impl Tokenizer {
|
|||||||
vec![(0, sentence.len())],
|
vec![(0, sentence.len())],
|
||||||
vec![0],
|
vec![0],
|
||||||
vec![1],
|
vec![1],
|
||||||
None,
|
vec![],
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -321,7 +321,7 @@ impl Tokenizer {
|
|||||||
offsets,
|
offsets,
|
||||||
vec![0; length],
|
vec![0; length],
|
||||||
vec![1; length],
|
vec![1; length],
|
||||||
None,
|
vec![],
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<Encoding>>>()?;
|
.collect::<Result<Vec<Encoding>>>()?;
|
||||||
|
|||||||
Reference in New Issue
Block a user