Add truncation ability

This commit is contained in:
Anthony MOI
2019-12-12 17:19:31 -05:00
parent 13df36ca55
commit f4cd78e98a
3 changed files with 169 additions and 0 deletions

View File

@ -7,3 +7,4 @@ pub mod normalizers;
pub mod pre_tokenizers;
pub mod processors;
pub mod tokenizer;
pub mod utils;

View File

@ -81,6 +81,9 @@ pub struct Encoding {
type_ids: Vec<u32>,
tokens: Vec<String>,
offsets: Vec<(usize, usize)>,
special_tokens_mask: Vec<u32>,
attention_mask: Vec<u32>,
overflowing: Option<Box<Encoding>>,
}
impl Encoding {
pub fn new(
@ -90,6 +93,9 @@ impl Encoding {
type_ids: Vec<u32>,
tokens: Vec<String>,
offsets: Vec<(usize, usize)>,
special_tokens_mask: Vec<u32>,
attention_mask: Vec<u32>,
overflowing: Option<Box<Encoding>>,
) -> Self {
Encoding {
original,
@ -98,6 +104,9 @@ impl Encoding {
type_ids,
tokens,
offsets,
special_tokens_mask,
attention_mask,
overflowing,
}
}
@ -124,6 +133,75 @@ impl Encoding {
pub fn get_offsets(&self) -> &[(usize, usize)] {
&self.offsets
}
/// Borrow the special-tokens mask, one entry per item in this encoding.
pub fn get_special_tokens_mask(&self) -> &[u32] {
    self.special_tokens_mask.as_slice()
}
/// Borrow the attention mask, one entry per item in this encoding.
pub fn get_attention_mask(&self) -> &[u32] {
    self.attention_mask.as_slice()
}
/// Move the overflowing part (if any) out of this encoding,
/// leaving `None` in its place.
pub fn take_overflowing(&mut self) -> Option<Box<Encoding>> {
    std::mem::replace(&mut self.overflowing, None)
}
/// Truncate this encoding to at most `max_len` items, moving everything
/// beyond that point into `self.overflowing`. When `stride > 0`, the
/// overflowing part is additionally prefixed with the last `stride` kept
/// items, so the two parts overlap.
pub fn truncate(&mut self, max_len: usize, stride: usize) {
    // `>=`, not `>`: when the encoding already fits exactly we must not
    // fabricate an empty overflowing part (and must not split the
    // original/normalized strings at all).
    if max_len >= self.ids.len() {
        return;
    }

    // Split every parallel sequence at the same index; the tails become
    // the overflowing part.
    let mut o_ids = self.ids.split_off(max_len);
    let mut o_type_ids = self.type_ids.split_off(max_len);
    let mut o_tokens = self.tokens.split_off(max_len);
    let mut o_offsets = self.offsets.split_off(max_len);
    let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
    let mut o_attent = self.attention_mask.split_off(max_len);

    // Figure out offsets for original and normalized.
    // TODO: We will be able to retrieve the right part of original
    // only when we have the alignment difference between both.
    // For now we use the normalized offset...
    let max = self
        .offsets
        .iter()
        .map(|(_, end)| *end)
        .max()
        .unwrap_or(0);
    // NOTE(review): `String::split_off` panics if `max` is not a char
    // boundary — assumes offsets are valid byte offsets into these
    // strings; confirm against how offsets are produced.
    let trunc_original = self.original.split_off(max);
    let trunc_normalized = self.normalized.split_off(max);

    if stride > 0 {
        // Prefix each overflowing sequence with the last `stride` kept
        // items to provide overlap between the two parts.
        o_ids = prepend_stride(&self.ids, o_ids, stride);
        o_type_ids = prepend_stride(&self.type_ids, o_type_ids, stride);
        o_tokens = prepend_stride(&self.tokens, o_tokens, stride);
        o_offsets = prepend_stride(&self.offsets, o_offsets, stride);
        o_spe_toks = prepend_stride(&self.special_tokens_mask, o_spe_toks, stride);
        o_attent = prepend_stride(&self.attention_mask, o_attent, stride);
    }

    self.overflowing = Some(Box::new(Encoding {
        original: trunc_original,
        normalized: trunc_normalized,
        ids: o_ids,
        type_ids: o_type_ids,
        tokens: o_tokens,
        offsets: o_offsets,
        special_tokens_mask: o_spe_toks,
        attention_mask: o_attent,
        overflowing: None,
    }));
}
}
/// Build the overflowing sequence by prefixing `current` with the last
/// `stride` elements of `previous`, kept in their original order.
/// If `previous` has fewer than `stride` elements, all of them are used.
fn prepend_stride<T: Clone>(previous: &[T], current: Vec<T>, stride: usize) -> Vec<T> {
    // Taking a tail slice directly avoids the original rev/take/rev dance
    // and the per-element clone closure.
    let start = previous.len().saturating_sub(stride);
    let tail = &previous[start..];
    let mut out = Vec::with_capacity(tail.len() + current.len());
    out.extend_from_slice(tail);
    out.extend(current);
    out
}
pub enum EncodeInput {
@ -238,6 +316,9 @@ impl Tokenizer {
type_ids: vec![type_id; length],
tokens,
offsets,
attention_mask: vec![1; length],
special_tokens_mask: vec![0; length],
overflowing: None,
})
};
@ -367,6 +448,14 @@ impl Tokenizer {
.collect::<Vec<_>>(),
]
.concat(),
special_tokens_mask: [
&encoding.special_tokens_mask[..],
&pair.special_tokens_mask[..],
]
.concat(),
attention_mask: [&encoding.attention_mask[..], &pair.attention_mask[..]]
.concat(),
overflowing: None,
}),
}
}

79
tokenizers/src/utils.rs Normal file
View File

@ -0,0 +1,79 @@
use crate::tokenizer::{Encoding, Result};
/// Errors that can arise while truncating encodings.
#[derive(Debug)]
pub enum Error {
    SequenceTooSmall,
    SecondSequenceNotProvided,
}

impl std::fmt::Display for Error {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Every variant shares the "Truncation error: " prefix; pick the
        // reason first, then format once.
        let reason = match self {
            Error::SequenceTooSmall => "Sequence is too small",
            Error::SecondSequenceNotProvided => "Second sequence not provided",
        };
        write!(fmt, "Truncation error: {}", reason)
    }
}

impl std::error::Error for Error {}
/// Strategy describing which sequence(s) to remove items from when the
/// combined encodings exceed the maximum length.
// `Debug` added: public types should be debuggable; `Eq` is free since
// the enum is field-less. Both additions are backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Repeatedly remove from whichever sequence is currently the longest.
    LongestFirst,
    /// Remove only from the first sequence.
    OnlyFirst,
    /// Remove only from the second sequence.
    OnlySecond,
}
/// Remove `to_remove` items in total from `encoding` / `pair_encoding`
/// according to `strategy`, forwarding `stride` to `Encoding::truncate`.
///
/// Returns the (possibly truncated) encodings, or an error when the
/// strategy cannot be applied: `OnlyFirst`/`OnlySecond` on a sequence
/// with `<= to_remove` items, or `OnlySecond` without a pair.
pub fn truncate_encodings(
    mut encoding: Encoding,
    mut pair_encoding: Option<Encoding>,
    to_remove: usize,
    strategy: TruncationStrategy,
    stride: usize,
) -> Result<(Encoding, Option<Encoding>)> {
    if to_remove == 0 {
        return Ok((encoding, pair_encoding));
    }

    match strategy {
        TruncationStrategy::LongestFirst => {
            let mut n_first = 0;
            let mut n_second = 0;
            // Track the *remaining* lengths while planning removals.
            // Comparing the untouched encodings on every iteration would
            // attribute all `to_remove` removals to whichever sequence
            // started out longer, instead of alternating once the two
            // lengths equalize.
            let mut len_first = encoding.get_ids().len();
            let mut len_second = pair_encoding
                .as_ref()
                .map(|e| e.get_ids().len())
                .unwrap_or(0);
            for _ in 0..to_remove {
                if pair_encoding.is_none() || len_first > len_second {
                    n_first += 1;
                    len_first = len_first.saturating_sub(1);
                } else {
                    n_second += 1;
                    len_second = len_second.saturating_sub(1);
                }
            }

            // `saturating_sub`: if `to_remove` exceeds the total number of
            // items we truncate down to empty instead of underflowing.
            encoding.truncate(encoding.get_ids().len().saturating_sub(n_first), stride);
            if let Some(pair) = pair_encoding.as_mut() {
                pair.truncate(pair.get_ids().len().saturating_sub(n_second), stride);
            }
        }
        TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
            // Select the single sequence this strategy is allowed to touch.
            let target = if strategy == TruncationStrategy::OnlyFirst {
                Ok(&mut encoding)
            } else if let Some(encoding) = pair_encoding.as_mut() {
                Ok(encoding)
            } else {
                Err(Box::new(Error::SecondSequenceNotProvided))
            }?;

            if target.get_ids().len() <= to_remove {
                return Err(Box::new(Error::SequenceTooSmall));
            }
            target.truncate(target.get_ids().len() - to_remove, stride);
        }
    }

    Ok((encoding, pair_encoding))
}