Tokenizer handles Truncation and Padding

This commit is contained in:
Anthony MOI
2019-12-17 15:15:58 -05:00
parent 4c51399b00
commit 5729d3656a
4 changed files with 168 additions and 71 deletions

View File

@ -1,56 +1,30 @@
use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::utils::{truncate_encodings, TruncationStrategy};
pub struct BertProcessing {
max_len: usize,
trunc_strategy: TruncationStrategy,
trunc_stride: usize,
sep: (String, u32),
cls: (String, u32),
}
impl BertProcessing {
pub fn new(
max_len: usize,
trunc_strategy: TruncationStrategy,
trunc_stride: usize,
sep: (String, u32),
cls: (String, u32),
) -> Self {
BertProcessing {
max_len,
trunc_strategy,
trunc_stride,
sep,
cls,
}
pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
BertProcessing { sep, cls }
}
}
impl PostProcessor for BertProcessing {
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
let special_token_len = if pair_encoding.is_some() { 3 } else { 2 };
let total_len = encoding.get_ids().len()
+ pair_encoding
.as_ref()
.map(|e| e.get_ids().len())
.unwrap_or(0)
+ special_token_len;
let need_trunc = if total_len > self.max_len {
total_len - self.max_len
fn added_tokens(
&self,
_encoding: &Encoding,
pair_encoding: &Option<Encoding>,
) -> Result<usize> {
if pair_encoding.is_some() {
Ok(3)
} else {
0
};
let (mut encoding, pair_encoding) = truncate_encodings(
encoding,
pair_encoding,
need_trunc,
self.trunc_strategy,
self.trunc_stride,
)?;
Ok(2)
}
}
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
// Prepare ids
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
let pair_ids = pair_encoding

View File

@ -1,4 +1,5 @@
/// The various possible padding directions
#[derive(Debug, Clone)]
pub enum PaddingDirection {
Left,
Right,
@ -155,8 +156,8 @@ impl Encoding {
pad_length: usize,
pad_id: u32,
pad_type_id: u32,
pad_token: String,
direction: PaddingDirection,
pad_token: &str,
direction: &PaddingDirection,
) {
if self.ids.len() > pad_length {
// We just do nothing if the wanted padding length is smaller than us
@ -166,13 +167,13 @@ impl Encoding {
let ids_pad = vec![pad_id; pad_length];
let type_ids_pad = vec![pad_type_id; pad_length];
let tokens_pad = vec![pad_token; pad_length];
let tokens_pad = vec![pad_token.to_owned(); pad_length];
let attention_pad = vec![0; pad_length];
let special_pad = vec![1; pad_length];
let offsets_pad = vec![(0, 0); pad_length];
match direction {
Left => {
PaddingDirection::Left => {
self.ids = [&ids_pad[..], &self.ids[..]].concat();
self.type_ids = [&type_ids_pad[..], &self.type_ids[..]].concat();
self.tokens = [&tokens_pad[..], &self.tokens[..]].concat();
@ -181,7 +182,7 @@ impl Encoding {
[&special_pad[..], &self.special_tokens_mask[..]].concat();
self.offsets = [&offsets_pad[..], &self.offsets[..]].concat();
}
Right => {
PaddingDirection::Right => {
self.ids = [&self.ids[..], &ids_pad[..]].concat();
self.type_ids = [&self.type_ids[..], &type_ids_pad[..]].concat();
self.tokens = [&self.tokens[..], &tokens_pad[..]].concat();

View File

@ -12,6 +12,9 @@
//! - PostProcessor: Takes care of the processing after tokenization. (Like truncating, padding,
//! ...)
//!
use crate::utils::{
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
};
use rayon::prelude::*;
use std::{
collections::HashMap,
@ -20,7 +23,7 @@ use std::{
};
mod encoding;
pub use encoding::Encoding;
pub use encoding::*;
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
@ -46,6 +49,7 @@ pub trait Model {
/// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
/// Truncating, Padding, etc... are PostProcessor steps
pub trait PostProcessor {
/// Returns how many special tokens this processor will add to the sequence
/// (resp. to the pair of sequences when `pair_encoding` is `Some`), so that
/// truncation can reserve room for them before processing happens.
fn added_tokens(&self, encoding: &Encoding, pair_encoding: &Option<Encoding>) -> Result<usize>;
/// Consumes the encoding(s) and produces the final, post-processed `Encoding`.
fn process(&self, encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding>;
}
@ -125,6 +129,10 @@ pub struct Tokenizer {
added_tokens: HashMap<AddedToken, u32>,
added_tokens_r: HashMap<u32, AddedToken>,
split_re: Option<regex::Regex>,
// General processing parameters
trunc: Option<TruncationParams>,
padding: Option<PaddingParams>,
}
impl Tokenizer {
@ -139,11 +147,13 @@ impl Tokenizer {
added_tokens: HashMap::new(),
added_tokens_r: HashMap::new(),
split_re: None,
trunc: None,
padding: None,
}
}
/// Set the normalizers
pub fn with_normalizers(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
/// Set the normalizer
pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer + Sync>) -> &Self {
self.normalizer = Some(normalizer);
self
}
@ -172,6 +182,18 @@ impl Tokenizer {
self
}
/// Set the truncation parameters
///
/// Passing `None` disables truncation (no `TruncationParams` will be applied
/// during post processing).
pub fn with_truncation(&mut self, trunc: Option<TruncationParams>) -> &Self {
self.trunc = trunc;
self
}
/// Set the padding strategy
///
/// Passing `None` disables padding. Note: with `PaddingStrategy::BatchLongest`
/// the actual padding is applied at the batch level, not per encoding.
pub fn with_padding(&mut self, padding: Option<PaddingParams>) -> &Self {
self.padding = padding;
self
}
/// Get the size of the vocabulary
pub fn get_vocab_size(&self) -> usize {
self.model.get_vocab_size()
@ -283,10 +305,17 @@ impl Tokenizer {
/// Encode all the sentences in parallel, using multiple threads
pub fn encode_batch(&self, inputs: Vec<EncodeInput>) -> Result<Vec<Encoding>> {
inputs
let encodings = inputs
.into_par_iter()
.map(|input| self.encode(input))
.collect()
.collect::<Result<Vec<Encoding>>>()?;
if let Some(params) = &self.padding {
// We do the padding here to make sure we handle the batch padding
pad_encodings(encodings, &params)
} else {
Ok(encodings)
}
}
/// Decode the given ids, back to a String
@ -376,20 +405,61 @@ impl Tokenizer {
/// Post processing logic, handling the case where there is no PostProcessor set
fn post_process(
&self,
mut encoding: Encoding,
encoding: Encoding,
pair_encoding: Option<Encoding>,
) -> Result<Encoding> {
if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding)
// 1. First we truncate if needed
let (mut encoding, pair_encoding) = {
if let Some(trunc) = &self.trunc {
let n_added_tokens = if let Some(processor) = &self.post_processor {
processor.added_tokens(&encoding, &pair_encoding)?
} else {
0
};
if n_added_tokens > 0 {
let params = TruncationParams {
max_length: trunc.max_length - n_added_tokens,
..*trunc
};
truncate_encodings(encoding, pair_encoding, &params)?
} else {
truncate_encodings(encoding, pair_encoding, &trunc)?
}
} else {
(encoding, pair_encoding)
}
};
// 2. Then We post process
let mut final_encoding = if let Some(processor) = &self.post_processor {
processor.process(encoding, pair_encoding)?
} else {
match pair_encoding {
None => Ok(encoding),
None => encoding,
Some(pair) => {
encoding.merge_with(pair);
Ok(encoding)
encoding
}
}
};
// 3. Then we pad if needed
if let Some(params) = &self.padding {
// We can only pad for a given size. If the Strategy is BatchLongest, it will be done
// when we handle a batch
if let PaddingStrategy::Fixed(size) = params.strategy {
final_encoding.pad(
size,
params.pad_id,
params.pad_type_id,
&params.pad_token,
&params.direction,
);
}
}
Ok(final_encoding)
}
/// Add the given tokens to the added vocabulary

View File

@ -1,15 +1,35 @@
use crate::tokenizer::{Encoding, Result};
use crate::tokenizer::{Encoding, PaddingDirection, Result};
/// Parameters controlling how encodings get truncated.
#[derive(Debug, Clone)]
pub struct TruncationParams {
// Target maximum length; sequences longer than this get truncated.
pub max_length: usize,
// Which sequence(s) to remove tokens from (longest first, only first, only second).
pub strategy: TruncationStrategy,
// Stride forwarded to `Encoding::truncate` when cutting a sequence.
pub stride: usize,
}
/// Parameters controlling how encodings get padded.
#[derive(Debug, Clone)]
pub struct PaddingParams {
// Fixed size vs. longest-in-batch.
pub strategy: PaddingStrategy,
// Whether padding is prepended (Left) or appended (Right).
pub direction: PaddingDirection,
// Id used for every padding position.
pub pad_id: u32,
// Type id used for every padding position.
pub pad_type_id: u32,
// Token string used for every padding position.
pub pad_token: String,
}
/// The padding target length selection strategy.
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
// Pad every encoding to the length of the longest one in the batch.
BatchLongest,
// Pad every encoding to the given fixed size.
Fixed(usize),
}
/// Errors that truncation can produce.
#[derive(Debug)]
pub enum Error {
// The target sequence is too small for the requested truncation.
SequenceTooSmall,
// `OnlySecond` truncation was requested but no pair encoding was given.
SecondSequenceNotProvided,
}
impl std::fmt::Display for Error {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::SequenceTooSmall => write!(fmt, "Truncation error: Sequence is too small"),
Error::SecondSequenceNotProvided => {
write!(fmt, "Truncation error: Second sequence not provided")
}
@ -18,7 +38,7 @@ impl std::fmt::Display for Error {
}
impl std::error::Error for Error {}
#[derive(Clone, Copy, PartialEq)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TruncationStrategy {
LongestFirst,
OnlyFirst,
@ -28,19 +48,27 @@ pub enum TruncationStrategy {
pub fn truncate_encodings(
mut encoding: Encoding,
mut pair_encoding: Option<Encoding>,
to_remove: usize,
strategy: TruncationStrategy,
stride: usize,
params: &TruncationParams,
) -> Result<(Encoding, Option<Encoding>)> {
if to_remove == 0 {
if params.max_length == 0 {
return Ok((encoding, pair_encoding));
}
match strategy {
match params.strategy {
TruncationStrategy::LongestFirst => {
let total_length = encoding.get_ids().len()
+ pair_encoding
.as_ref()
.map(|e| e.get_ids().len())
.unwrap_or(0);
let to_remove = if total_length > params.max_length {
total_length - params.max_length
} else {
0
};
let mut n_first = 0;
let mut n_second = 0;
for _ in 0..to_remove {
if pair_encoding.is_none()
|| encoding.get_ids().len() > pair_encoding.as_ref().unwrap().get_ids().len()
@ -51,13 +79,13 @@ pub fn truncate_encodings(
}
}
encoding.truncate(encoding.get_ids().len() - n_first, stride);
encoding.truncate(encoding.get_ids().len() - n_first, params.stride);
if let Some(encoding) = pair_encoding.as_mut() {
encoding.truncate(encoding.get_ids().len() - n_second, stride);
encoding.truncate(encoding.get_ids().len() - n_second, params.stride);
}
}
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
let target = if strategy == TruncationStrategy::OnlyFirst {
let target = if params.strategy == TruncationStrategy::OnlyFirst {
Ok(&mut encoding)
} else if let Some(encoding) = pair_encoding.as_mut() {
Ok(encoding)
@ -65,13 +93,37 @@ pub fn truncate_encodings(
Err(Box::new(Error::SecondSequenceNotProvided))
}?;
if target.get_ids().len() <= to_remove {
return Err(Box::new(Error::SequenceTooSmall));
if target.get_ids().len() > params.max_length {
target.truncate(params.max_length, params.stride);
}
target.truncate(target.get_ids().len() - to_remove, stride);
}
}
Ok((encoding, pair_encoding))
}
/// Pad every `Encoding` in the batch to a common length.
///
/// The target length is either the fixed size from the params, or — for
/// `PaddingStrategy::BatchLongest` — the length of the longest encoding in
/// the batch. Encodings already at or above the target length are left
/// untouched by `Encoding::pad`. An empty batch is returned unchanged.
pub fn pad_encodings(
    mut encodings: Vec<Encoding>,
    params: &PaddingParams,
) -> Result<Vec<Encoding>> {
    if encodings.is_empty() {
        return Ok(encodings);
    }

    // Resolve the concrete length every encoding should be padded to.
    let pad_length = match params.strategy {
        PaddingStrategy::Fixed(size) => size,
        PaddingStrategy::BatchLongest => encodings
            .iter()
            .map(|encoding| encoding.get_ids().len())
            .max()
            .unwrap(), // safe: the batch was checked non-empty above
    };

    for encoding in &mut encodings {
        encoding.pad(
            pad_length,
            params.pad_id,
            params.pad_type_id,
            &params.pad_token,
            &params.direction,
        );
    }

    Ok(encodings)
}