Tokenizer handles Truncation and Padding

This commit is contained in:
Anthony MOI
2019-12-17 15:15:58 -05:00
parent 4c51399b00
commit 5729d3656a
4 changed files with 168 additions and 71 deletions

View File

@ -1,56 +1,30 @@
use crate::tokenizer::{Encoding, PostProcessor, Result}; use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::utils::{truncate_encodings, TruncationStrategy};
pub struct BertProcessing { pub struct BertProcessing {
max_len: usize,
trunc_strategy: TruncationStrategy,
trunc_stride: usize,
sep: (String, u32), sep: (String, u32),
cls: (String, u32), cls: (String, u32),
} }
impl BertProcessing { impl BertProcessing {
pub fn new( pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
max_len: usize, BertProcessing { sep, cls }
trunc_strategy: TruncationStrategy,
trunc_stride: usize,
sep: (String, u32),
cls: (String, u32),
) -> Self {
BertProcessing {
max_len,
trunc_strategy,
trunc_stride,
sep,
cls,
}
} }
} }
impl PostProcessor for BertProcessing { impl PostProcessor for BertProcessing {
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> { fn added_tokens(
let special_token_len = if pair_encoding.is_some() { 3 } else { 2 }; &self,
let total_len = encoding.get_ids().len() _encoding: &Encoding,
+ pair_encoding pair_encoding: &Option<Encoding>,
.as_ref() ) -> Result<usize> {
.map(|e| e.get_ids().len()) if pair_encoding.is_some() {
.unwrap_or(0) Ok(3)
+ special_token_len;
let need_trunc = if total_len > self.max_len {
total_len - self.max_len
} else { } else {
0 Ok(2)
}; }
let (mut encoding, pair_encoding) = truncate_encodings( }
encoding,
pair_encoding,
need_trunc,
self.trunc_strategy,
self.trunc_stride,
)?;
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
// Prepare ids // Prepare ids
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat(); let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
let pair_ids = pair_encoding let pair_ids = pair_encoding

View File

@ -1,4 +1,5 @@
/// The various possible padding directions /// The various possible padding directions
#[derive(Debug, Clone)]
pub enum PaddingDirection { pub enum PaddingDirection {
Left, Left,
Right, Right,
@ -155,8 +156,8 @@ impl Encoding {
pad_length: usize, pad_length: usize,
pad_id: u32, pad_id: u32,
pad_type_id: u32, pad_type_id: u32,
pad_token: String, pad_token: &str,
direction: PaddingDirection, direction: &PaddingDirection,
) { ) {
if self.ids.len() > pad_length { if self.ids.len() > pad_length {
// We just do nothing if the wanted padding length is smaller than us // We just do nothing if the wanted padding length is smaller than us
@ -166,13 +167,13 @@ impl Encoding {
let ids_pad = vec![pad_id; pad_length]; let ids_pad = vec![pad_id; pad_length];
let type_ids_pad = vec![pad_type_id; pad_length]; let type_ids_pad = vec![pad_type_id; pad_length];
let tokens_pad = vec![pad_token; pad_length]; let tokens_pad = vec![pad_token.to_owned(); pad_length];
let attention_pad = vec![0; pad_length]; let attention_pad = vec![0; pad_length];
let special_pad = vec![1; pad_length]; let special_pad = vec![1; pad_length];
let offsets_pad = vec![(0, 0); pad_length]; let offsets_pad = vec![(0, 0); pad_length];
match direction { match direction {
Left => { PaddingDirection::Left => {
self.ids = [&ids_pad[..], &self.ids[..]].concat(); self.ids = [&ids_pad[..], &self.ids[..]].concat();
self.type_ids = [&type_ids_pad[..], &self.type_ids[..]].concat(); self.type_ids = [&type_ids_pad[..], &self.type_ids[..]].concat();
self.tokens = [&tokens_pad[..], &self.tokens[..]].concat(); self.tokens = [&tokens_pad[..], &self.tokens[..]].concat();
@ -181,7 +182,7 @@ impl Encoding {
[&special_pad[..], &self.special_tokens_mask[..]].concat(); [&special_pad[..], &self.special_tokens_mask[..]].concat();
self.offsets = [&offsets_pad[..], &self.offsets[..]].concat(); self.offsets = [&offsets_pad[..], &self.offsets[..]].concat();
} }
Right => { PaddingDirection::Right => {
self.ids = [&self.ids[..], &ids_pad[..]].concat(); self.ids = [&self.ids[..], &ids_pad[..]].concat();
self.type_ids = [&self.type_ids[..], &type_ids_pad[..]].concat(); self.type_ids = [&self.type_ids[..], &type_ids_pad[..]].concat();
self.tokens = [&self.tokens[..], &tokens_pad[..]].concat(); self.tokens = [&self.tokens[..], &tokens_pad[..]].concat();

View File

@ -12,6 +12,9 @@
//! - PostProcessor: Takes care of the processing after tokenization. (Like truncating, padding, //! - PostProcessor: Takes care of the processing after tokenization. (Like truncating, padding,
//! ...) //! ...)
//! //!
use crate::utils::{
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
};
use rayon::prelude::*; use rayon::prelude::*;
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -20,7 +23,7 @@ use std::{
}; };
mod encoding; mod encoding;
pub use encoding::Encoding; pub use encoding::*;
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>; pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
@ -46,6 +49,7 @@ pub trait Model {
/// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer. /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
/// Truncating, Padding, etc... are PostProcessor steps /// Truncating, Padding, etc... are PostProcessor steps
pub trait PostProcessor { pub trait PostProcessor {
fn added_tokens(&self, encoding: &Encoding, pair_encoding: &Option<Encoding>) -> Result<usize>;
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>; fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
} }
@ -125,6 +129,10 @@ pub struct Tokenizer {
added_tokens: HashMap<AddedToken, u32>, added_tokens: HashMap<AddedToken, u32>,
added_tokens_r: HashMap<u32, AddedToken>, added_tokens_r: HashMap<u32, AddedToken>,
split_re: Option<regex::Regex>, split_re: Option<regex::Regex>,
// General processing parameters
trunc: Option<TruncationParams>,
padding: Option<PaddingParams>,
} }
impl Tokenizer { impl Tokenizer {
@ -139,11 +147,13 @@ impl Tokenizer {
added_tokens: HashMap::new(), added_tokens: HashMap::new(),
added_tokens_r: HashMap::new(), added_tokens_r: HashMap::new(),
split_re: None, split_re: None,
trunc: None,
padding: None,
} }
} }
/// Set the normalizers /// Set the normalizer
pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self { pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
self.normalizer = Some(normalizer); self.normalizer = Some(normalizer);
self self
} }
@ -172,6 +182,18 @@ impl Tokenizer {
self self
} }
/// Configure truncation for this tokenizer.
///
/// Stores `trunc` on the tokenizer, replacing whatever was stored before.
/// Passing `None` clears the stored parameters.
///
/// Returns a shared reference to the tokenizer.
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &Self {
    self.trunc = trunc;
    // Hand the exclusive borrow back as a shared one.
    &*self
}
/// Configure padding for this tokenizer.
///
/// Stores `padding` on the tokenizer, replacing whatever was stored before.
/// Passing `None` clears the stored parameters.
///
/// Returns a shared reference to the tokenizer.
pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &Self {
    self.padding = padding;
    // Hand the exclusive borrow back as a shared one.
    &*self
}
/// Get the size of the vocabulary /// Get the size of the vocabulary
pub fn get_vocab_size(&self) -> usize { pub fn get_vocab_size(&self) -> usize {
self.model.get_vocab_size() self.model.get_vocab_size()
@ -283,10 +305,17 @@ impl Tokenizer {
/// Encode all the sentences in parallel, using multiple threads /// Encode all the sentences in parallel, using multiple threads
pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> { pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
inputs let encodings = inputs
.into_par_iter() .into_par_iter()
.map(|input| self.encode(input)) .map(|input| self.encode(input))
.collect() .collect::<Result<Vec<Encoding>>>()?;
if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
pad_encodings(encodings, &params)
} else {
Ok(encodings)
}
} }
/// Decode the given ids, back to a String /// Decode the given ids, back to a String
@ -376,20 +405,61 @@ impl Tokenizer {
/// Post processing logic, handling the case where there is no PostProcessor set /// Post processing logic, handling the case where there is no PostProcessor set
fn post_process( fn post_process(
&self, &self,
mut encoding: Encoding, encoding: Encoding,
pair_encoding: Option<Encoding>, pair_encoding: Option<Encoding>,
) -> Result<Encoding> { ) -> Result<Encoding> {
if let Some(processor) = &self.post_processor { // 1. First we truncate if needed
processor.process(encoding, pair_encoding) let (mut encoding, pair_encoding) = {
if let Some(trunc) = &self.trunc {
let n_added_tokens = if let Some(processor) = &self.post_processor {
processor.added_tokens(&encoding, &pair_encoding)?
} else {
0
};
if n_added_tokens > 0 {
let params = TruncationParams {
max_length: trunc.max_length - n_added_tokens,
..*trunc
};
truncate_encodings(encoding, pair_encoding, &params)?
} else {
truncate_encodings(encoding, pair_encoding, &trunc)?
}
} else {
(encoding, pair_encoding)
}
};
// 2. Then We post process
let mut final_encoding = if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding)?
} else { } else {
match pair_encoding { match pair_encoding {
None => Ok(encoding), None => encoding,
Some(pair) => { Some(pair) => {
encoding.merge_with(pair); encoding.merge_with(pair);
Ok(encoding) encoding
} }
} }
};
// 3. Then we pad if needed
if let Some(params) = &self.padding {
// We can only pad for a given size. If the Strategy is BatchLongest, it will be done
// when we handle a batch
if let PaddingStrategy::Fixed(size) = params.strategy {
final_encoding.pad(
size,
params.pad_id,
params.pad_type_id,
&params.pad_token,
&params.direction,
);
}
} }
Ok(final_encoding)
} }
/// Add the given tokens to the added vocabulary /// Add the given tokens to the added vocabulary

View File

@ -1,15 +1,35 @@
use crate::tokenizer::{Encoding, Result}; use crate::tokenizer::{Encoding, PaddingDirection, Result};
/// Parameters read by `truncate_encodings` to shorten an encoding and its optional pair.
#[derive(Debug, Clone)]
pub struct TruncationParams {
/// Length budget. `truncate_encodings` returns its input unchanged when this is 0.
pub max_length: usize,
/// Selects which sequence(s) `truncate_encodings` shortens to fit `max_length`.
pub strategy: TruncationStrategy,
/// Forwarded as the second argument of `Encoding::truncate`; its exact effect is defined there.
pub stride: usize,
}
/// Parameters forwarded to `Encoding::pad` when padding a single encoding or a batch.
#[derive(Debug, Clone)]
pub struct PaddingParams {
/// How the target length is chosen (fixed size or longest encoding of the batch).
pub strategy: PaddingStrategy,
/// Side on which padding is added; passed to `Encoding::pad` by reference.
pub direction: PaddingDirection,
/// Passed to `Encoding::pad` as `pad_id`.
pub pad_id: u32,
/// Passed to `Encoding::pad` as `pad_type_id`.
pub pad_type_id: u32,
/// Passed to `Encoding::pad` as `pad_token`.
pub pad_token: String,
}
/// How the padding target length is determined.
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
/// Pad to the largest number of ids found among the encodings of a batch
/// (computed in `pad_encodings`).
BatchLongest,
/// Pad to the given length.
Fixed(usize),
}
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {
SequenceTooSmall,
SecondSequenceNotProvided, SecondSequenceNotProvided,
} }
impl std::fmt::Display for Error { impl std::fmt::Display for Error {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self { match self {
Error::SequenceTooSmall => write!(fmt, "Truncation error: Sequence is too small"),
Error::SecondSequenceNotProvided => { Error::SecondSequenceNotProvided => {
write!(fmt, "Truncation error: Second sequence not provided") write!(fmt, "Truncation error: Second sequence not provided")
} }
@ -18,7 +38,7 @@ impl std::fmt::Display for Error {
} }
impl std::error::Error for Error {} impl std::error::Error for Error {}
#[derive(Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub enum TruncationStrategy { pub enum TruncationStrategy {
LongestFirst, LongestFirst,
OnlyFirst, OnlyFirst,
@ -28,19 +48,27 @@ pub enum TruncationStrategy {
pub fn truncate_encodings( pub fn truncate_encodings(
mut encoding: Encoding, mut encoding: Encoding,
mut pair_encoding: Option<Encoding>, mut pair_encoding: Option<Encoding>,
to_remove: usize, params: &TruncationParams,
strategy: TruncationStrategy,
stride: usize,
) -> Result<(Encoding, Option<Encoding>)> { ) -> Result<(Encoding, Option<Encoding>)> {
if to_remove == 0 { if params.max_length == 0 {
return Ok((encoding, pair_encoding)); return Ok((encoding, pair_encoding));
} }
match strategy { match params.strategy {
TruncationStrategy::LongestFirst => { TruncationStrategy::LongestFirst => {
let total_length = encoding.get_ids().len()
+ pair_encoding
.as_ref()
.map(|e| e.get_ids().len())
.unwrap_or(0);
let to_remove = if total_length > params.max_length {
total_length - params.max_length
} else {
0
};
let mut n_first = 0; let mut n_first = 0;
let mut n_second = 0; let mut n_second = 0;
for _ in 0..to_remove { for _ in 0..to_remove {
if pair_encoding.is_none() if pair_encoding.is_none()
|| encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len() || encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
@ -51,13 +79,13 @@ pub fn truncate_encodings(
} }
} }
encoding.truncate(encoding.get_ids().len() - n_first, stride); encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
if let Some(encoding) = pair_encoding.as_mut() { if let Some(encoding) = pair_encoding.as_mut() {
encoding.truncate(encoding.get_ids().len() - n_second, stride); encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
} }
} }
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => { TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
let target = if strategy == TruncationStrategy::OnlyFirst { let target = if params.strategy == TruncationStrategy::OnlyFirst {
Ok(&mut encoding) Ok(&mut encoding)
} else if let Some(encoding) = pair_encoding.as_mut() { } else if let Some(encoding) = pair_encoding.as_mut() {
Ok(encoding) Ok(encoding)
@ -65,13 +93,37 @@ pub fn truncate_encodings(
Err(Box::new(Error::SecondSequenceNotProvided)) Err(Box::new(Error::SecondSequenceNotProvided))
}?; }?;
if target.get_ids().len() <= to_remove { if target.get_ids().len() > params.max_length {
return Err(Box::new(Error::SequenceTooSmall)); target.truncate(params.max_length, params.stride);
} }
target.truncate(target.get_ids().len() - to_remove, stride);
} }
} }
Ok((encoding, pair_encoding)) Ok((encoding, pair_encoding))
} }
/// Pad every encoding of a batch to a common length.
///
/// The target length is the size carried by `PaddingStrategy::Fixed`, or, for
/// `PaddingStrategy::BatchLongest`, the largest number of ids found among
/// `encodings`. Each encoding is then passed to `Encoding::pad` together with
/// the pad id, pad type id, pad token and direction taken from `params`.
///
/// An empty batch is returned untouched. No `Err` value is constructed in this
/// function; it always returns `Ok`.
pub fn pad_encodings(
    mut encodings: Vec<Encoding>,
    params: &PaddingParams,
) -> Result<Vec<Encoding>> {
    // Nothing to measure or pad in an empty batch.
    if encodings.is_empty() {
        return Ok(encodings);
    }

    let target_len = match params.strategy {
        // The batch is non-empty here, so the fold yields the true maximum.
        PaddingStrategy::BatchLongest => encodings
            .iter()
            .map(|enc| enc.get_ids().len())
            .fold(0, usize::max),
        PaddingStrategy::Fixed(size) => size,
    };

    encodings.iter_mut().for_each(|enc| {
        enc.pad(
            target_len,
            params.pad_id,
            params.pad_type_id,
            &params.pad_token,
            &params.direction,
        )
    });

    Ok(encodings)
}