mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Add truncation ability
This commit is contained in:
@ -7,3 +7,4 @@ pub mod normalizers;
|
||||
pub mod pre_tokenizers;
|
||||
pub mod processors;
|
||||
pub mod tokenizer;
|
||||
pub mod utils;
|
||||
|
@ -81,6 +81,9 @@ pub struct Encoding {
|
||||
type_ids: Vec<u32>,
|
||||
tokens: Vec<String>,
|
||||
offsets: Vec<(usize, usize)>,
|
||||
special_tokens_mask: Vec<u32>,
|
||||
attention_mask: Vec<u32>,
|
||||
overflowing: Option<Box<Encoding>>,
|
||||
}
|
||||
impl Encoding {
|
||||
pub fn new(
|
||||
@ -90,6 +93,9 @@ impl Encoding {
|
||||
type_ids: Vec<u32>,
|
||||
tokens: Vec<String>,
|
||||
offsets: Vec<(usize, usize)>,
|
||||
special_tokens_mask: Vec<u32>,
|
||||
attention_mask: Vec<u32>,
|
||||
overflowing: Option<Box<Encoding>>,
|
||||
) -> Self {
|
||||
Encoding {
|
||||
original,
|
||||
@ -98,6 +104,9 @@ impl Encoding {
|
||||
type_ids,
|
||||
tokens,
|
||||
offsets,
|
||||
special_tokens_mask,
|
||||
attention_mask,
|
||||
overflowing,
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,6 +133,75 @@ impl Encoding {
|
||||
pub fn get_offsets(&self) -> &[(usize, usize)] {
|
||||
&self.offsets
|
||||
}
|
||||
|
||||
pub fn get_special_tokens_mask(&self) -> &[u32] {
|
||||
&self.special_tokens_mask
|
||||
}
|
||||
|
||||
pub fn get_attention_mask(&self) -> &[u32] {
|
||||
&self.attention_mask
|
||||
}
|
||||
|
||||
/// Moves the overflowing encoding (if any) out of `self`, leaving `None` behind.
pub fn take_overflowing(&mut self) -> Option<Box<Encoding>> {
    std::mem::take(&mut self.overflowing)
}
|
||||
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize) {
|
||||
if max_len > self.ids.len() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut o_ids = self.ids.split_off(max_len);
|
||||
let mut o_type_ids = self.type_ids.split_off(max_len);
|
||||
let mut o_tokens = self.tokens.split_off(max_len);
|
||||
let mut o_offsets = self.offsets.split_off(max_len);
|
||||
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||
let mut o_attent = self.attention_mask.split_off(max_len);
|
||||
|
||||
// Figure out offsets for original and normalized
|
||||
// TODO: We will be able to retrive the right part of original
|
||||
// only when we will have the alignment difference between both
|
||||
// For now we will use the normalized offset...
|
||||
let max = self
|
||||
.offsets
|
||||
.iter()
|
||||
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
|
||||
let trunc_original = self.original.split_off(max);
|
||||
let trunc_normalized = self.normalized.split_off(max);
|
||||
|
||||
if stride > 0 {
|
||||
o_ids = prepend_stride(&self.ids, o_ids, stride);
|
||||
o_type_ids = prepend_stride(&self.type_ids, o_type_ids, stride);
|
||||
o_tokens = prepend_stride(&self.tokens, o_tokens, stride);
|
||||
o_offsets = prepend_stride(&self.offsets, o_offsets, stride);
|
||||
o_spe_toks = prepend_stride(&self.special_tokens_mask, o_spe_toks, stride);
|
||||
o_attent = prepend_stride(&self.attention_mask, o_attent, stride);
|
||||
}
|
||||
|
||||
self.overflowing = Some(Box::new(Encoding {
|
||||
original: trunc_original,
|
||||
normalized: trunc_normalized,
|
||||
ids: o_ids,
|
||||
type_ids: o_type_ids,
|
||||
tokens: o_tokens,
|
||||
offsets: o_offsets,
|
||||
special_tokens_mask: o_spe_toks,
|
||||
attention_mask: o_attent,
|
||||
overflowing: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a new vector made of the last `stride` elements of `previous` followed by
/// every element of `current`.
///
/// When `previous` holds fewer than `stride` elements, all of them are used. The
/// elements taken from `previous` keep their original order and are cloned; `current`
/// is consumed.
fn prepend_stride<T: Clone>(previous: &[T], current: Vec<T>, stride: usize) -> Vec<T> {
    // Index of the first element of `previous` that belongs to the overlap window.
    let start = previous.len().saturating_sub(stride);
    let overlap = &previous[start..];

    let mut merged = Vec::with_capacity(overlap.len() + current.len());
    merged.extend_from_slice(overlap);
    merged.extend(current);
    merged
}
|
||||
|
||||
pub enum EncodeInput {
|
||||
@ -238,6 +316,9 @@ impl Tokenizer {
|
||||
type_ids: vec![type_id; length],
|
||||
tokens,
|
||||
offsets,
|
||||
attention_mask: vec![1; length],
|
||||
special_tokens_mask: vec![0; length],
|
||||
overflowing: None,
|
||||
})
|
||||
};
|
||||
|
||||
@ -367,6 +448,14 @@ impl Tokenizer {
|
||||
.collect::<Vec<_>>(),
|
||||
]
|
||||
.concat(),
|
||||
special_tokens_mask: [
|
||||
&encoding.special_tokens_mask[..],
|
||||
&pair.special_tokens_mask[..],
|
||||
]
|
||||
.concat(),
|
||||
attention_mask: [&encoding.attention_mask[..], &pair.attention_mask[..]]
|
||||
.concat(),
|
||||
overflowing: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
79
tokenizers/src/utils.rs
Normal file
79
tokenizers/src/utils.rs
Normal file
@ -0,0 +1,79 @@
|
||||
use crate::tokenizer::{Encoding, Result};
|
||||
|
||||
/// Errors reported by `truncate_encodings`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The sequence selected for truncation does not hold more tokens than the number
    /// that has to be removed.
    SequenceTooSmall,
    /// The second sequence was selected for truncation but no pair encoding was given.
    SecondSequenceNotProvided,
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Error::SequenceTooSmall => write!(fmt, "Truncation error: Sequence is too small"),
|
||||
Error::SecondSequenceNotProvided => {
|
||||
write!(fmt, "Truncation error: Second sequence not provided")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Marks `Error` as a standard error type. The body is empty, so only the trait's
// default method implementations are used.
impl std::error::Error for Error {}
|
||||
|
||||
/// Selects which sequence(s) `truncate_encodings` removes tokens from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Take tokens from the longer sequence, preferring the second one when a pair is
    /// present and the lengths are equal.
    LongestFirst,
    /// Take tokens from the first sequence only.
    OnlyFirst,
    /// Take tokens from the second sequence only.
    OnlySecond,
}
|
||||
|
||||
pub fn truncate_encodings(
|
||||
mut encoding: Encoding,
|
||||
mut pair_encoding: Option<Encoding>,
|
||||
to_remove: usize,
|
||||
strategy: TruncationStrategy,
|
||||
stride: usize,
|
||||
) -> Result<(Encoding, Option<Encoding>)> {
|
||||
if to_remove == 0 {
|
||||
return Ok((encoding, pair_encoding));
|
||||
}
|
||||
|
||||
match strategy {
|
||||
TruncationStrategy::LongestFirst => {
|
||||
let mut n_first = 0;
|
||||
let mut n_second = 0;
|
||||
|
||||
for _ in 0..to_remove {
|
||||
if pair_encoding.is_none()
|
||||
|| encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
|
||||
{
|
||||
n_first += 1;
|
||||
} else {
|
||||
n_second += 1;
|
||||
}
|
||||
}
|
||||
|
||||
encoding.truncate(encoding.get_ids().len() - n_first, stride);
|
||||
pair_encoding
|
||||
.as_mut()
|
||||
.map(|encoding| encoding.truncate(encoding.get_ids().len() - n_second, stride));
|
||||
}
|
||||
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
|
||||
let target = if strategy == TruncationStrategy::OnlyFirst {
|
||||
Ok(&mut encoding)
|
||||
} else {
|
||||
if let Some(encoding) = pair_encoding.as_mut() {
|
||||
Ok(encoding)
|
||||
} else {
|
||||
Err(Box::new(Error::SecondSequenceNotProvided))
|
||||
}
|
||||
}?;
|
||||
|
||||
if target.get_ids().len() <= to_remove {
|
||||
return Err(Box::new(Error::SequenceTooSmall));
|
||||
}
|
||||
|
||||
target.truncate(target.get_ids().len() - to_remove, stride);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((encoding, pair_encoding))
|
||||
}
|
Reference in New Issue
Block a user