mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Tokenizer handles Truncation and Padding
This commit is contained in:
@ -1,56 +1,30 @@
|
||||
use crate::tokenizer::{Encoding, PostProcessor, Result};
|
||||
use crate::utils::{truncate_encodings, TruncationStrategy};
|
||||
|
||||
pub struct BertProcessing {
|
||||
max_len: usize,
|
||||
trunc_strategy: TruncationStrategy,
|
||||
trunc_stride: usize,
|
||||
|
||||
sep: (String, u32),
|
||||
cls: (String, u32),
|
||||
}
|
||||
|
||||
impl BertProcessing {
|
||||
pub fn new(
|
||||
max_len: usize,
|
||||
trunc_strategy: TruncationStrategy,
|
||||
trunc_stride: usize,
|
||||
sep: (String, u32),
|
||||
cls: (String, u32),
|
||||
) -> Self {
|
||||
BertProcessing {
|
||||
max_len,
|
||||
trunc_strategy,
|
||||
trunc_stride,
|
||||
sep,
|
||||
cls,
|
||||
}
|
||||
pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
|
||||
BertProcessing { sep, cls }
|
||||
}
|
||||
}
|
||||
|
||||
impl PostProcessor for BertProcessing {
|
||||
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
|
||||
let special_token_len = if pair_encoding.is_some() { 3 } else { 2 };
|
||||
let total_len = encoding.get_ids().len()
|
||||
+ pair_encoding
|
||||
.as_ref()
|
||||
.map(|e| e.get_ids().len())
|
||||
.unwrap_or(0)
|
||||
+ special_token_len;
|
||||
|
||||
let need_trunc = if total_len > self.max_len {
|
||||
total_len - self.max_len
|
||||
fn added_tokens(
|
||||
&self,
|
||||
_encoding: &Encoding,
|
||||
pair_encoding: &Option<Encoding>,
|
||||
) -> Result<usize> {
|
||||
if pair_encoding.is_some() {
|
||||
Ok(3)
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let (mut encoding, pair_encoding) = truncate_encodings(
|
||||
encoding,
|
||||
pair_encoding,
|
||||
need_trunc,
|
||||
self.trunc_strategy,
|
||||
self.trunc_stride,
|
||||
)?;
|
||||
Ok(2)
|
||||
}
|
||||
}
|
||||
|
||||
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
|
||||
// Prepare ids
|
||||
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
|
||||
let pair_ids = pair_encoding
|
||||
|
@ -1,4 +1,5 @@
|
||||
/// The various possible padding directions
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PaddingDirection {
|
||||
Left,
|
||||
Right,
|
||||
@ -155,8 +156,8 @@ impl Encoding {
|
||||
pad_length: usize,
|
||||
pad_id: u32,
|
||||
pad_type_id: u32,
|
||||
pad_token: String,
|
||||
direction: PaddingDirection,
|
||||
pad_token: &str,
|
||||
direction: &PaddingDirection,
|
||||
) {
|
||||
if self.ids.len() > pad_length {
|
||||
// We just do nothing if the wanted padding length is smaller than us
|
||||
@ -166,13 +167,13 @@ impl Encoding {
|
||||
|
||||
let ids_pad = vec![pad_id; pad_length];
|
||||
let type_ids_pad = vec![pad_type_id; pad_length];
|
||||
let tokens_pad = vec![pad_token; pad_length];
|
||||
let tokens_pad = vec![pad_token.to_owned(); pad_length];
|
||||
let attention_pad = vec![0; pad_length];
|
||||
let special_pad = vec![1; pad_length];
|
||||
let offsets_pad = vec![(0, 0); pad_length];
|
||||
|
||||
match direction {
|
||||
Left => {
|
||||
PaddingDirection::Left => {
|
||||
self.ids = [&ids_pad[..], &self.ids[..]].concat();
|
||||
self.type_ids = [&type_ids_pad[..], &self.type_ids[..]].concat();
|
||||
self.tokens = [&tokens_pad[..], &self.tokens[..]].concat();
|
||||
@ -181,7 +182,7 @@ impl Encoding {
|
||||
[&special_pad[..], &self.special_tokens_mask[..]].concat();
|
||||
self.offsets = [&offsets_pad[..], &self.offsets[..]].concat();
|
||||
}
|
||||
Right => {
|
||||
PaddingDirection::Right => {
|
||||
self.ids = [&self.ids[..], &ids_pad[..]].concat();
|
||||
self.type_ids = [&self.type_ids[..], &type_ids_pad[..]].concat();
|
||||
self.tokens = [&self.tokens[..], &tokens_pad[..]].concat();
|
||||
|
@ -12,6 +12,9 @@
|
||||
//! - PostProcessor: Takes care of the processing after tokenization. (Like truncating, padding,
|
||||
//! ...)
|
||||
//!
|
||||
use crate::utils::{
|
||||
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
|
||||
};
|
||||
use rayon::prelude::*;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
@ -20,7 +23,7 @@ use std::{
|
||||
};
|
||||
|
||||
mod encoding;
|
||||
pub use encoding::Encoding;
|
||||
pub use encoding::*;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
|
||||
|
||||
@ -46,6 +49,7 @@ pub trait Model {
|
||||
/// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
|
||||
/// Truncating, Padding, etc... are PostProcessor steps
|
||||
pub trait PostProcessor {
|
||||
fn added_tokens(&self, encoding: &Encoding, pair_encoding: &Option<Encoding>) -> Result<usize>;
|
||||
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
|
||||
}
|
||||
|
||||
@ -125,6 +129,10 @@ pub struct Tokenizer {
|
||||
added_tokens: HashMap<AddedToken, u32>,
|
||||
added_tokens_r: HashMap<u32, AddedToken>,
|
||||
split_re: Option<regex::Regex>,
|
||||
|
||||
// General processing parameters
|
||||
trunc: Option<TruncationParams>,
|
||||
padding: Option<PaddingParams>,
|
||||
}
|
||||
|
||||
impl Tokenizer {
|
||||
@ -139,11 +147,13 @@ impl Tokenizer {
|
||||
added_tokens: HashMap::new(),
|
||||
added_tokens_r: HashMap::new(),
|
||||
split_re: None,
|
||||
trunc: None,
|
||||
padding: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the normalizers
|
||||
pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
|
||||
/// Set the normalizer
|
||||
pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
|
||||
self.normalizer = Some(normalizer);
|
||||
self
|
||||
}
|
||||
@ -172,6 +182,18 @@ impl Tokenizer {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the truncation parameters
|
||||
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &Self {
|
||||
self.trunc = trunc;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the padding strategy
|
||||
pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &Self {
|
||||
self.padding = padding;
|
||||
self
|
||||
}
|
||||
|
||||
/// Get the size of the vocabulary
|
||||
pub fn get_vocab_size(&self) -> usize {
|
||||
self.model.get_vocab_size()
|
||||
@ -283,10 +305,17 @@ impl Tokenizer {
|
||||
|
||||
/// Encode all the sentences in parallel, using multiple threads
|
||||
pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
|
||||
inputs
|
||||
let encodings = inputs
|
||||
.into_par_iter()
|
||||
.map(|input| self.encode(input))
|
||||
.collect()
|
||||
.collect::<Result<Vec<Encoding>>>()?;
|
||||
|
||||
if let Some(params) = &self.padding {
|
||||
// We do the padding here to make sure we handle the batch padding
|
||||
pad_encodings(encodings, ¶ms)
|
||||
} else {
|
||||
Ok(encodings)
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode the given ids, back to a String
|
||||
@ -376,22 +405,63 @@ impl Tokenizer {
|
||||
/// Post processing logic, handling the case where there is no PostProcessor set
|
||||
fn post_process(
|
||||
&self,
|
||||
mut encoding: Encoding,
|
||||
encoding: Encoding,
|
||||
pair_encoding: Option<Encoding>,
|
||||
) -> Result<Encoding> {
|
||||
if let Some(processor) = &self.post_processor {
|
||||
processor.process(encoding, pair_encoding)
|
||||
// 1. First we truncate if needed
|
||||
let (mut encoding, pair_encoding) = {
|
||||
if let Some(trunc) = &self.trunc {
|
||||
let n_added_tokens = if let Some(processor) = &self.post_processor {
|
||||
processor.added_tokens(&encoding, &pair_encoding)?
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
if n_added_tokens > 0 {
|
||||
let params = TruncationParams {
|
||||
max_length: trunc.max_length - n_added_tokens,
|
||||
..*trunc
|
||||
};
|
||||
truncate_encodings(encoding, pair_encoding, ¶ms)?
|
||||
} else {
|
||||
truncate_encodings(encoding, pair_encoding, &trunc)?
|
||||
}
|
||||
} else {
|
||||
(encoding, pair_encoding)
|
||||
}
|
||||
};
|
||||
|
||||
// 2. Then We post process
|
||||
let mut final_encoding = if let Some(processor) = &self.post_processor {
|
||||
processor.process(encoding, pair_encoding)?
|
||||
} else {
|
||||
match pair_encoding {
|
||||
None => Ok(encoding),
|
||||
None => encoding,
|
||||
Some(pair) => {
|
||||
encoding.merge_with(pair);
|
||||
Ok(encoding)
|
||||
encoding
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 3. Then we pad if needed
|
||||
if let Some(params) = &self.padding {
|
||||
// We can only pad for a given size. If the Strategy is BatchLongest, it will be done
|
||||
// when we handle a batch
|
||||
if let PaddingStrategy::Fixed(size) = params.strategy {
|
||||
final_encoding.pad(
|
||||
size,
|
||||
params.pad_id,
|
||||
params.pad_type_id,
|
||||
¶ms.pad_token,
|
||||
¶ms.direction,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(final_encoding)
|
||||
}
|
||||
|
||||
/// Add the given tokens to the added vocabulary
|
||||
pub fn add_tokens(&mut self, tokens: &[AddedToken]) -> usize {
|
||||
let mut ignored = 0;
|
||||
|
@ -1,15 +1,35 @@
|
||||
use crate::tokenizer::{Encoding, Result};
|
||||
use crate::tokenizer::{Encoding, PaddingDirection, Result};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TruncationParams {
|
||||
pub max_length: usize,
|
||||
pub strategy: TruncationStrategy,
|
||||
pub stride: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PaddingParams {
|
||||
pub strategy: PaddingStrategy,
|
||||
pub direction: PaddingDirection,
|
||||
pub pad_id: u32,
|
||||
pub pad_type_id: u32,
|
||||
pub pad_token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PaddingStrategy {
|
||||
BatchLongest,
|
||||
Fixed(usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
SequenceTooSmall,
|
||||
SecondSequenceNotProvided,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Error::SequenceTooSmall => write!(fmt, "Truncation error: Sequence is too small"),
|
||||
Error::SecondSequenceNotProvided => {
|
||||
write!(fmt, "Truncation error: Second sequence not provided")
|
||||
}
|
||||
@ -18,7 +38,7 @@ impl std::fmt::Display for Error {
|
||||
}
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum TruncationStrategy {
|
||||
LongestFirst,
|
||||
OnlyFirst,
|
||||
@ -28,19 +48,27 @@ pub enum TruncationStrategy {
|
||||
pub fn truncate_encodings(
|
||||
mut encoding: Encoding,
|
||||
mut pair_encoding: Option<Encoding>,
|
||||
to_remove: usize,
|
||||
strategy: TruncationStrategy,
|
||||
stride: usize,
|
||||
params: &TruncationParams,
|
||||
) -> Result<(Encoding, Option<Encoding>)> {
|
||||
if to_remove == 0 {
|
||||
if params.max_length == 0 {
|
||||
return Ok((encoding, pair_encoding));
|
||||
}
|
||||
|
||||
match strategy {
|
||||
match params.strategy {
|
||||
TruncationStrategy::LongestFirst => {
|
||||
let total_length = encoding.get_ids().len()
|
||||
+ pair_encoding
|
||||
.as_ref()
|
||||
.map(|e| e.get_ids().len())
|
||||
.unwrap_or(0);
|
||||
let to_remove = if total_length > params.max_length {
|
||||
total_length - params.max_length
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
let mut n_first = 0;
|
||||
let mut n_second = 0;
|
||||
|
||||
for _ in 0..to_remove {
|
||||
if pair_encoding.is_none()
|
||||
|| encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
|
||||
@ -51,13 +79,13 @@ pub fn truncate_encodings(
|
||||
}
|
||||
}
|
||||
|
||||
encoding.truncate(encoding.get_ids().len() - n_first, stride);
|
||||
encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
|
||||
if let Some(encoding) = pair_encoding.as_mut() {
|
||||
encoding.truncate(encoding.get_ids().len() - n_second, stride);
|
||||
encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
|
||||
}
|
||||
}
|
||||
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
|
||||
let target = if strategy == TruncationStrategy::OnlyFirst {
|
||||
let target = if params.strategy == TruncationStrategy::OnlyFirst {
|
||||
Ok(&mut encoding)
|
||||
} else if let Some(encoding) = pair_encoding.as_mut() {
|
||||
Ok(encoding)
|
||||
@ -65,13 +93,37 @@ pub fn truncate_encodings(
|
||||
Err(Box::new(Error::SecondSequenceNotProvided))
|
||||
}?;
|
||||
|
||||
if target.get_ids().len() <= to_remove {
|
||||
return Err(Box::new(Error::SequenceTooSmall));
|
||||
if target.get_ids().len() > params.max_length {
|
||||
target.truncate(params.max_length, params.stride);
|
||||
}
|
||||
|
||||
target.truncate(target.get_ids().len() - to_remove, stride);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((encoding, pair_encoding))
|
||||
}
|
||||
|
||||
pub fn pad_encodings(
|
||||
mut encodings: Vec<Encoding>,
|
||||
params: &PaddingParams,
|
||||
) -> Result<Vec<Encoding>> {
|
||||
if encodings.is_empty() {
|
||||
return Ok(encodings);
|
||||
}
|
||||
|
||||
let pad_length = match params.strategy {
|
||||
PaddingStrategy::Fixed(size) => size,
|
||||
PaddingStrategy::BatchLongest => encodings.iter().map(|e| e.get_ids().len()).max().unwrap(),
|
||||
};
|
||||
|
||||
for encoding in encodings.iter_mut() {
|
||||
encoding.pad(
|
||||
pad_length,
|
||||
params.pad_id,
|
||||
params.pad_type_id,
|
||||
¶ms.pad_token,
|
||||
¶ms.direction,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(encodings)
|
||||
}
|
||||
|
Reference in New Issue
Block a user